diff --git a/paddle/fluid/operators/controlflow/conditional_block_op.cc b/paddle/fluid/operators/controlflow/conditional_block_op.cc index a1ef201912dd92a33705c4cd0d5c01639d9fe495..60e5aa61b99d5404088b0bd798850bb5d547e1a3 100644 --- a/paddle/fluid/operators/controlflow/conditional_block_op.cc +++ b/paddle/fluid/operators/controlflow/conditional_block_op.cc @@ -16,10 +16,8 @@ limitations under the License. */ #include #include "paddle/fluid/framework/new_executor/standalone_executor.h" -#include "paddle/fluid/operators/assign_op.h" #include "paddle/fluid/operators/controlflow/control_flow_op_helper.h" #include "paddle/phi/core/flags.h" -#include "paddle/phi/kernels/funcs/math_function.h" #ifdef PADDLE_WITH_DNNL #include "paddle/fluid/platform/mkldnn_helper.h" @@ -226,129 +224,6 @@ class ConditionalBlockGradOp : public ConditionalOp { mutable std::shared_ptr exec_{nullptr}; mutable std::unique_ptr ctx_{nullptr}; mutable std::shared_ptr core_{nullptr}; - - private: - void AssignLocalGradientToParentScope( - const platform::Place &place, - const framework::Scope &cur_scope, - const framework::Scope &parent_scope, - const std::vector &inside_grads, - const std::vector &outside_grads, - const std::vector &inputs) const { - std::vector assign_zero_outside_grads; - std::vector assign_zero_inputs; - for (size_t i = 0; i < outside_grads.size(); ++i) { - const std::string &outside_grad_name = outside_grads[i]; - const std::string &inside_grad_name = inside_grads[i]; - VLOG(4) << "[assign local]" - << "inside_grad_name = " << inside_grad_name - << ", outside_grad_name = " << outside_grad_name; - framework::Variable *outside_var = - parent_scope.FindVar(outside_grad_name); - if (outside_var == nullptr) { - continue; - } - framework::Variable *inside_var = - cur_scope.FindLocalVar(inside_grad_name); - if (inside_var == nullptr) { - assign_zero_outside_grads.emplace_back(outside_grad_name); - assign_zero_inputs.emplace_back(inputs[i]); - continue; - } - platform::DeviceContext *dev_ctx = - platform::DeviceContextPool::Instance().Get(place); - framework::VisitVarType(*inside_var, - AssignFunctor(outside_var, *dev_ctx)); - } - // Assign zero to the grad_vars that are in outside_grads but not in - // inside_grads - AssignZeroToParentScope( - place, parent_scope, assign_zero_inputs, assign_zero_outside_grads); - } - - void AssignZeroToParentScope( - const platform::Place &place, - const framework::Scope &scope, - const std::vector &inputs, - const std::vector &outside_grads) const { - for (size_t i = 0; i < outside_grads.size(); ++i) { - const std::string &outside_grad_name = outside_grads[i]; - const std::string &input_name = inputs[i]; - VLOG(4) << "[assign zero]" - << "input_name = " << input_name - << ", outside_grad_name = " << outside_grad_name; - framework::Variable *input_var = scope.FindVar(input_name); - if (input_var == nullptr) { - continue; - } - framework::Variable *outside_var = scope.FindVar(outside_grad_name); - if (outside_var == nullptr) { - continue; - } - - if (input_var->IsType()) { - PADDLE_ENFORCE_EQ( - outside_var->IsType(), - true, - platform::errors::InvalidArgument( - "Type of outside_var %s is NOT phi::DenseTensor, which " - "doesn't match input_var %s.", - outside_grad_name, - input_name)); - AssignZeroToOutsideTensor(place, - scope, - input_var->Get(), - outside_var->GetMutable()); - } else if (input_var->IsType()) { - PADDLE_ENFORCE_EQ(outside_var->IsType(), - true, - platform::errors::InvalidArgument( - "Type of outside_var %s is NOT LoDTensorArray, " - "which doesn't match 
input_var %s.", - outside_grad_name, - input_name)); - const auto &input_tensors = input_var->Get(); - auto *outside_tensors = - outside_var->GetMutable(); - if (outside_tensors->empty()) { - outside_tensors->resize(input_tensors.size()); - } - PADDLE_ENFORCE_EQ(input_tensors.size(), - outside_tensors->size(), - platform::errors::InvalidArgument( - "LoDTensorArray outside_var %s doen't have same " - "size as input_var %s.", - outside_grad_name, - input_name)); - for (size_t j = 0; j < input_tensors.size(); ++j) { - AssignZeroToOutsideTensor( - place, scope, input_tensors[j], &((*outside_tensors)[j])); - } - } else { - // TODO(huihuangzheng): add support for SelectedRows - PADDLE_THROW(platform::errors::InvalidArgument( - "Conditional block grad op doesn't support non-phi::DenseTensor " - "output " - "now.")); - } - } - } - - void AssignZeroToOutsideTensor(const platform::Place &place, - const framework::Scope &cur_scope, - const phi::DenseTensor &input_tensor, - phi::DenseTensor *outside_tensor) const { - if (!input_tensor.IsInitialized() || input_tensor.numel() == 0) { - return; - } - VLOG(4) << "Assigning zero to " << outside_tensor; - outside_tensor->Resize(input_tensor.dims()); - outside_tensor->mutable_data(place, input_tensor.dtype()); - const platform::DeviceContext *dev_ctx = - platform::DeviceContextPool::Instance().Get(place); - phi::funcs::set_constant(*dev_ctx, outside_tensor, 0.0f); - outside_tensor->set_lod(input_tensor.lod()); - } }; template diff --git a/paddle/fluid/operators/controlflow/control_flow_op_helper.h b/paddle/fluid/operators/controlflow/control_flow_op_helper.h index 82b57831f935618305fdfe5a2e0b52ae7725c7bf..0d08ae6d686630417f73933b88d5312addd4b773 100644 --- a/paddle/fluid/operators/controlflow/control_flow_op_helper.h +++ b/paddle/fluid/operators/controlflow/control_flow_op_helper.h @@ -15,6 +15,8 @@ #pragma once #include "paddle/fluid/framework/new_executor/standalone_executor.h" +#include "paddle/fluid/operators/assign_op.h" +#include "paddle/phi/kernels/funcs/math_function.h" namespace paddle { namespace operators { @@ -54,5 +56,123 @@ static void BuildScopeForControlFlowOp( } } +static void AssignZeroToOutsideTensor(const platform::Place &place, + const framework::Scope &cur_scope, + const phi::DenseTensor &input_tensor, + phi::DenseTensor *outside_tensor) { + if (!input_tensor.IsInitialized() || input_tensor.numel() == 0) { + return; + } + VLOG(4) << "Assigning zero to " << outside_tensor; + outside_tensor->Resize(input_tensor.dims()); + outside_tensor->mutable_data(place, input_tensor.dtype()); + const platform::DeviceContext *dev_ctx = + platform::DeviceContextPool::Instance().Get(place); + phi::funcs::set_constant(*dev_ctx, outside_tensor, 0.0f); + outside_tensor->set_lod(input_tensor.lod()); +} + +static void AssignZeroToParentScope( + const platform::Place &place, + const framework::Scope &scope, + const std::vector &inputs, + const std::vector &outside_grads) { + for (size_t i = 0; i < outside_grads.size(); ++i) { + const std::string &outside_grad_name = outside_grads[i]; + const std::string &input_name = inputs[i]; + VLOG(4) << "[assign zero]" + << "input_name = " << input_name + << ", outside_grad_name = " << outside_grad_name; + framework::Variable *input_var = scope.FindVar(input_name); + if (input_var == nullptr) { + continue; + } + framework::Variable *outside_var = scope.FindVar(outside_grad_name); + if (outside_var == nullptr) { + continue; + } + + if (input_var->IsType()) { + PADDLE_ENFORCE_EQ( + outside_var->IsType(), + true, + 
platform::errors::InvalidArgument( + "Type of outside_var %s is NOT phi::DenseTensor, which " + "doesn't match input_var %s.", + outside_grad_name, + input_name)); + AssignZeroToOutsideTensor(place, + scope, + input_var->Get(), + outside_var->GetMutable()); + } else if (input_var->IsType()) { + PADDLE_ENFORCE_EQ(outside_var->IsType(), + true, + platform::errors::InvalidArgument( + "Type of outside_var %s is NOT LoDTensorArray, " + "which doesn't match input_var %s.", + outside_grad_name, + input_name)); + const auto &input_tensors = input_var->Get(); + auto *outside_tensors = + outside_var->GetMutable(); + if (outside_tensors->empty()) { + outside_tensors->resize(input_tensors.size()); + } + PADDLE_ENFORCE_EQ(input_tensors.size(), + outside_tensors->size(), + platform::errors::InvalidArgument( + "LoDTensorArray outside_var %s doen't have same " + "size as input_var %s.", + outside_grad_name, + input_name)); + for (size_t j = 0; j < input_tensors.size(); ++j) { + AssignZeroToOutsideTensor( + place, scope, input_tensors[j], &((*outside_tensors)[j])); + } + } else { + // TODO(huihuangzheng): add support for SelectedRows + PADDLE_THROW(platform::errors::InvalidArgument( + "Conditional block grad op doesn't support non-phi::DenseTensor " + "output " + "now.")); + } + } +} + +static void AssignLocalGradientToParentScope( + const platform::Place &place, + const framework::Scope &cur_scope, + const framework::Scope &parent_scope, + const std::vector &inside_grads, + const std::vector &outside_grads, + const std::vector &inputs) { + std::vector assign_zero_outside_grads; + std::vector assign_zero_inputs; + for (size_t i = 0; i < outside_grads.size(); ++i) { + const std::string &outside_grad_name = outside_grads[i]; + const std::string &inside_grad_name = inside_grads[i]; + VLOG(4) << "[assign local]" + << "inside_grad_name = " << inside_grad_name + << ", outside_grad_name = " << outside_grad_name; + framework::Variable *outside_var = parent_scope.FindVar(outside_grad_name); + if (outside_var == nullptr) { + continue; + } + framework::Variable *inside_var = cur_scope.FindLocalVar(inside_grad_name); + if (inside_var == nullptr) { + assign_zero_outside_grads.emplace_back(outside_grad_name); + assign_zero_inputs.emplace_back(inputs[i]); + continue; + } + platform::DeviceContext *dev_ctx = + platform::DeviceContextPool::Instance().Get(place); + framework::VisitVarType(*inside_var, AssignFunctor(outside_var, *dev_ctx)); + } + // Assign zero to the grad_vars that are in outside_grads but not in + // inside_grads + AssignZeroToParentScope( + place, parent_scope, assign_zero_inputs, assign_zero_outside_grads); +} } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/controlflow/pylayer_op.cc b/paddle/fluid/operators/controlflow/pylayer_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..eef62289d76f580390eecd16e3a8a850b5f953d5 --- /dev/null +++ b/paddle/fluid/operators/controlflow/pylayer_op.cc @@ -0,0 +1,255 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/controlflow/pylayer_op.h" + +#include "paddle/fluid/operators/assign_op.h" +#include "paddle/fluid/operators/controlflow/control_flow_op_helper.h" +#include "paddle/phi/core/flags.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace paddle { +namespace operators { + +namespace { // NOLINT +enum class PyLayerBlockIndex { kFORWARD = 0, kBACKWARD = 1, kNONE = 2 }; +} // namespace + +const char PyLayerOp::kInputs[] = "Input"; +const char PyLayerOp::kOutputs[] = "Out"; +const char PyLayerOp::kScope[] = "Scope"; +const char PyLayerOp::kSkipEagerDeletionVars[] = "skip_eager_deletion_vars"; +const char PyLayerOp::kBlocks[] = "blocks"; + +void PyLayerOp::CreateInterpreter( + const platform::Place &dev_place, + const framework::BlockDesc &block, + framework::Scope *cur_scope, + const std::vector &skip_vars) const { + if (!core_ || !platform::is_same_place(core_->GetPlace(), dev_place)) { + VLOG(10) << "[interpreterCore cache]" << core_.get(); + VLOG_IF(10, core_) << platform::is_same_place(core_->GetPlace(), dev_place); + + framework::interpreter::ExecutionConfig execution_config; + execution_config.create_local_scope = false; + execution_config.used_for_control_flow_op = true; + execution_config.skip_gc_vars = + std::set(skip_vars.begin(), skip_vars.end()); + + core_.reset(new framework::InterpreterCore( + dev_place, block, cur_scope, execution_config)); + VLOG(10) << "[interpreterCore] created:" << core_; + } else { + // NOTE: Borrowed from + // `paddle/fluid/operators/controlflow/control_flow_op_helper.h` + // TODO(MarioLulab): Add PyLayer Helper ? + BuildScopeForControlFlowOp(*core_, block, cur_scope); + core_->reset_scope(cur_scope); + } +} + +class PyLayerForwardOp : public PyLayerOp { + public: + PyLayerForwardOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : PyLayerOp(type, inputs, outputs, attrs) {} + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const { + auto *scope_var = scope.FindVar(Output(kScope)); + PADDLE_ENFORCE_NOT_NULL( + scope_var, + platform::errors::PreconditionNotMet( + "Expect Scope variable to be set in pylayer_op, but " + "got a null Scope variable. Please set the Scope variable.")); + + auto *scopes = scope_var->GetMutable>(); + scopes->resize(1); + scopes->front() = &scope.NewScope(); + + auto &cur_scope = *scopes->front(); + auto &blocks = + Attr>(PyLayerOp::kBlocks); + PADDLE_ENFORCE_GT( + blocks.size(), + 0, + platform::errors::InvalidArgument( + "Expect blocks contains at least 1 block, but got: %d", + blocks.size())); + + framework::BlockDesc *forward_block = + blocks[static_cast(PyLayerBlockIndex::kFORWARD)]; + VLOG(3) << "PyLayer forward_block block.idx = " << forward_block->ID() + << ", scope = " << &cur_scope; + + auto &skip_vars = Attr>(kSkipEagerDeletionVars); + + LOG_FIRST_N(INFO, 1) << "[ControlFlow][PyLayer] New Executor is Running."; + + CreateInterpreter(dev_place, *forward_block, &cur_scope, skip_vars); + PADDLE_ENFORCE_NOT_NULL(core_, platform::errors::Fatal("core_ is nullptr")); + core_->Run({}, false); + } +}; + +class PyLayerForwardInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *context) const override { + // TODO(MarioLulab): do nothing. 
+ } +}; + +template +class PyLayerBackwardMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr grad_op) const override { + grad_op->SetType("pylayer_grad"); + grad_op->SetInput(PyLayerOp::kInputs, this->Input(PyLayerOp::kInputs)); + grad_op->SetInput(framework::GradVarName(PyLayerOp::kOutputs), + this->OutputGrad(PyLayerOp::kOutputs)); + grad_op->SetInput(PyLayerOp::kScope, this->Output(PyLayerOp::kScope)); + + auto fwd_inputs = this->InputGrad(PyLayerOp::kInputs, false); + grad_op->SetOutput(framework::GradVarName(PyLayerOp::kInputs), fwd_inputs); + + const std::vector &blocks = PADDLE_GET_CONST( + std::vector, this->GetAttr(PyLayerOp::kBlocks)); + PADDLE_ENFORCE_GT( + blocks.size(), + static_cast(PyLayerBlockIndex::kBACKWARD), + platform::errors::InvalidArgument( + "Expect blocks contains at least 2 block, but got: %d", + blocks.size())); + grad_op->SetBlockAttr( + "backward_block", + blocks[static_cast(PyLayerBlockIndex::kBACKWARD)]); + } +}; + +class PyLayerBackwardOp : public PyLayerOp { + public: + PyLayerBackwardOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : PyLayerOp(type, inputs, outputs, attrs) {} + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + const auto &inputs = Inputs(PyLayerOp::kInputs); + const auto &outside_grads = + Outputs(framework::GradVarName(PyLayerOp::kInputs)); + std::vector inside_grads; + inside_grads.reserve(inputs.size()); + for (auto &in : inputs) { + inside_grads.emplace_back(framework::GradVarName(in)); + } + + PADDLE_ENFORCE_EQ( + inside_grads.size(), + outside_grads.size(), + platform::errors::InvalidArgument( + "Mismatch inside_grads.size(): %d, and outside_grads.size(): %d", + inside_grads.size(), + outside_grads.size())); + + auto *scope_var = scope.FindVar(Input(PyLayerOp::kScope)); + PADDLE_ENFORCE_NOT_NULL( + scope_var, + platform::errors::PreconditionNotMet( + "Expect Scope variable to be set in pylayer_op, but " + "got a null Scope variable. Please set the Scope variable.")); + auto &scopes = scope_var->Get>(); + PADDLE_ENFORCE_GT( + scopes.size(), + 0, + platform::errors::InvalidArgument( + "Expect Scope variable contains at least 1 scope, but got: %d", + scopes.size())); + framework::Scope &cur_scope = *(scopes[0]); + + auto *backward_block = Attr("backward_block"); + VLOG(3) << "Static PyLayer backward block.idx = " << backward_block->ID() + << ", scope = " << &cur_scope; + + LOG_FIRST_N(INFO, 1) + << "[ControlFlow][PyLayerBackwardOp] New Executor is Running."; + + CreateInterpreter(dev_place, *backward_block, &cur_scope, inside_grads); + PADDLE_ENFORCE_NOT_NULL(core_, platform::errors::Fatal("core_ is nullptr")); + + core_->Run({}, false); + + // NOTE: It's neccessary. The reason of associating `inside_grads` and + // `outside_grads` at runtime `RunImpl` instead of `assgin` op at block is + // that the Var name of grad_op's outputs may be changed in the + // `append_backward` function (e.g. `_addup_repetitive_outputs_`). + AssignLocalGradientToParentScope( + dev_place, cur_scope, scope, inside_grads, outside_grads, inputs); + + // Release the cur_scope, otherwise memory leakage occurs. 
+ scope.DeleteScope(&cur_scope); + return; + } +}; + +class PyLayerBackwardInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *context) const override { + if (context->HasInputs(PyLayerOp::kInputs) && + context->HasOutputs(framework::GradVarName(PyLayerOp::kInputs))) { + context->SetOutputsDim(framework::GradVarName(PyLayerOp::kInputs), + context->GetInputsDim(PyLayerOp::kInputs)); + } + } +}; + +class PyLayerBackwardInferVarType : public framework::VarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto forward_input_size = ctx->InputSize(PyLayerOp::kInputs); + auto backward_output_size = + ctx->OutputSize(framework::GradVarName(PyLayerOp::kInputs)); + PADDLE_ENFORCE_EQ(forward_input_size, + backward_output_size, + platform::errors::InvalidArgument( + "input_size and output_size should be equal for " + "pylayer_grad op.")); + for (size_t i = 0; i < backward_output_size; ++i) { + ctx->SyncTypeAndDataType( + PyLayerOp::kInputs, framework::GradVarName(PyLayerOp::kInputs), i); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(pylayer, + ops::PyLayerForwardOp, + ops::PyLayerForwardInferShape, + ops::PyLayerForwardOpProtoMaker, + ops::PyLayerBackwardMaker); +REGISTER_OPERATOR(pylayer_grad, + ops::PyLayerBackwardOp, + ops::PyLayerBackwardInferShape, + ops::PyLayerBackwardInferVarType); diff --git a/paddle/fluid/operators/controlflow/pylayer_op.h b/paddle/fluid/operators/controlflow/pylayer_op.h new file mode 100644 index 0000000000000000000000000000000000000000..e06daad78041d9967d342b2268edf6d4dcc2a437 --- /dev/null +++ b/paddle/fluid/operators/controlflow/pylayer_op.h @@ -0,0 +1,75 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include + +#include "paddle/fluid/framework/new_executor/interpretercore.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/var_type.h" + +namespace paddle { +namespace operators { + +class PyLayerOp : public framework::OperatorBase { + public: + PyLayerOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + + static const char kInputs[]; + static const char kOutputs[]; + static const char kScope[]; + static const char kSkipEagerDeletionVars[]; + static const char kBlocks[]; + + protected: + void CreateInterpreter(const platform::Place &dev_place, + const framework::BlockDesc &block, + framework::Scope *scope, + const std::vector &skip_vars) const; + + protected: + mutable std::shared_ptr core_{nullptr}; +}; + +class PyLayerForwardOpProtoMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput(PyLayerOp::kInputs, "The input variables of the sub-block.") + .AsDuplicable(); + AddOutput(PyLayerOp::kOutputs, "The output variables of the sub-block.") + .AsDuplicable(); + // TODO(MarioLulab): Must Use std::vector here ? + AddOutput(PyLayerOp::kScope, + "(std::vector) The scope of static pylayer block."); + AddAttr>( + "blocks", "The blocks of PyLayer operator"); + AddComment(R"DOC(PyLayer operator + +TO-DO: added by luqi + + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 0047ffce5683914cb8bccb79fd2ab3039ac33e09..688108128242aa31ef1e2c281fca37371261683a 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -2165,6 +2165,11 @@ attrs : {axis : Axis, reduce : Reduce} +- op : pylayer + backward : pylayer_grad + extra : + attrs : ['str[] skip_eager_deletion_vars = {}'] + - op : qr backward : qr_grad inputs : diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 35e8dba75488781df70462efad80a7d56439b921..5a43b7f930200b087f0a66af522633f38959538d 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -2716,6 +2716,7 @@ class Operator: 'go', 'rnn_memory_helper_grad', 'conditional_block', + 'pylayer', 'while', 'send', 'recv', @@ -4250,6 +4251,8 @@ class Block: ignore_ops = { 'conditional_block', 'conditional_block_grad', + 'pylayer', + 'pylayer_grad', 'recurrent', 'recurrent_grad', 'while', diff --git a/python/paddle/static/nn/__init__.py b/python/paddle/static/nn/__init__.py index 8397c16db45a03d317d31f908356e0f3018fb1fe..d144f87ec32cb918b7094753cd3d6301cce10ca7 100755 --- a/python/paddle/static/nn/__init__.py +++ b/python/paddle/static/nn/__init__.py @@ -58,6 +58,7 @@ from .sequence_lod import sequence_enumerate # noqa: F401 from .sequence_lod import sequence_reverse # noqa: F401 from .control_flow import cond +from .static_pylayer import static_pylayer __all__ = [ # noqa 'fc', @@ -66,6 +67,7 @@ __all__ = [ # noqa 'embedding', 'case', 'cond', + 'static_pylayer', 'conv2d', 'conv2d_transpose', 'conv3d', diff --git a/python/paddle/static/nn/static_pylayer.py b/python/paddle/static/nn/static_pylayer.py new file mode 100644 index 0000000000000000000000000000000000000000..e5fd171d32663caeac74c6605f6761f782a7562a --- /dev/null +++ b/python/paddle/static/nn/static_pylayer.py @@ -0,0 +1,332 @@ +# Copyright (c) 2023 
PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from paddle.common_ops_import import LayerHelper, check_type, in_dygraph_mode +from paddle.fluid import core +from paddle.fluid.backward import _append_grad_suffix_ +from paddle.fluid.framework import Variable +from paddle.utils import flatten, map_structure + +# NOTE(MarioLulab): Borrowed from `python/paddle/static/nn/control_flow.py` +from .control_flow import BlockGuard, copy_var_to_parent_block + + +class StaticPyLayerBlockGuard(BlockGuard): + def __init__(self, block_manager): + check_type( + block_manager, + "block", + StaticPyLayerBlock, + "StaticPyLayerBlockGuard", + ) + super().__init__(block_manager.helper.main_program) + self.block_manager = block_manager + + def __enter__(self): + super().__enter__() + return self.block_manager + + def __exit__(self, exc_type, exc_val, exc_tb): + self.block_manager.complete() + return super().__exit__(exc_type, exc_val, exc_tb) + + +class StaticPyLayerBlock: + def __init__(self, inputs, name=None): + for each_input in inputs: + check_type(each_input, "input", Variable, "StaticPyLayerBlock") + + # used to specify the `Input` to `pylayer` op + self.fwd_inputs = inputs + # used to specify the `Out` to `pylayer` op + self.fwd_outputs = [] + + self.helper = LayerHelper("static_pylayer_block", name=name) + self.fwd_op_id = None + self._forward_block_id = None + self._backward_block_id = None + self.var_old_to_new = {} + + def block(self, is_backward_block=False): + self.is_backward_block = is_backward_block + return StaticPyLayerBlockGuard(self) + + @property + def forward_block_index(self): + return self._forward_block_id + + @property + def backward_block_index(self): + return self._backward_block_id + + @property + def fwd_op_index(self): + return self.fwd_op_id + + def complete_forward_block(self): + inside_block = self.helper.main_program.current_block() + parent_block = self.helper.main_program.block(inside_block.parent_idx) + self._forward_block_id = inside_block.idx + + step_scope = parent_block.create_var( + type=core.VarDesc.VarType.STEP_SCOPES + ) + + pylayer_op = parent_block.append_op( + type='pylayer', + inputs={ + 'Input': self.fwd_inputs, + }, + outputs={"Out": self.fwd_outputs, "Scope": [step_scope]}, + attrs={ + 'blocks': [inside_block], + }, + ) + + self.fwd_op_id = pylayer_op.idx + + def complete_backward_block(self): + inside_block = self.helper.main_program.current_block() + parent_block = self.helper.main_program.block(inside_block.parent_idx) + + self._backward_block_id = inside_block.idx + # set OpRole to `backward` + for op in inside_block.ops: + op_role_attr_name = ( + core.op_proto_and_checker_maker.kOpRoleAttrName() + ) + backward = core.op_proto_and_checker_maker.OpRole.Backward + op.desc._set_attr(op_role_attr_name, backward) + inside_block._set_forward_block_idx(self.forward_block_index) + + # NOTE(MarioLulab): The reason of renaming the var name in the inside block is that + # we need to associating `inside_grads` and 
`outside_grads` at + # runtime `RunImpl` in pylayer op + for old_var_name, new_var_name in self.var_old_to_new.items(): + # TODO(MarioLulab): need to remove recursively in ``sub_block`` + + # NOTE(MarioLulab): The reason for not using Block._rename_var is that `old_var_name` does not correspond to a Variable instance in Block + # and Block._rename_var will raise ValueError. + inside_block.desc._rename_var( + old_var_name.encode(), new_var_name.encode() + ) + + # update `blocks` attr by appending backward_block + forward_block_desc = parent_block.program.block( + self.forward_block_index + ).desc + backward_block_desc = inside_block.desc + parent_block.ops[self.fwd_op_index].desc.set_blocks_attr( + "blocks", [forward_block_desc, backward_block_desc] + ) + + def complete(self): + if not self.is_backward_block: + return self.complete_forward_block() + else: + return self.complete_backward_block() + + +# TODO(MarioLulab): +# Need to support non-Variable in ``inputs`` +def static_pylayer(forward_fn, inputs, backward_fn=None, name=None): + """ + This API returns ``forward_fn(inputs)``, and two sub-blocks are created based on + the logic of ``forward_fn`` and ``backward_fn``, with the operator ``pylayer`` + holding information about the two blocks. + + ``forward_fn`` and ``backward_fn`` should return a nested structure of tensors. + A nested structure of tensors in PaddlePaddle is a tensor, a tuple of tensors, or + a list of tensors. + + Note: + 1. If ``backward_fn`` is not None, the user needs to keep the number of inputs to ``forward_fn`` the same as the + number of outputs to ``backward_fn``, and the number of outputs to ``forward_fn`` + the same as the number of inputs to ``backward_fn``. + + 2. If ``backward_fn`` is None, the ``stop_gradient`` attr of all Variables in ``inputs`` is expected to be True. + Otherwise the backward pass might produce unexpected results. + + 3. This API can only be used under static graph mode. + + Args: + forward_fn (callable): A callable to be performed in the forward pass + inputs (list[Variable]): The list of input Variables to ``forward_fn`` + backward_fn (callable, optional): A callable to be performed in the backward pass + name (str, optional): The default value is ``None``. Normally users + don't have to set this parameter. + + Returns: + Variable|list(Variable)|tuple(Variable): returns the output of ``forward_fn(inputs)`` + + Examples: + ..
code-block:: python + + import paddle + import numpy as np + + # + # pseudocode: + # y = exp(x) + # dx = 2 * exp(dy) + # + + paddle.enable_static() + + def forward_fn(x): + return paddle.exp(x) + + def backward_fn(dy): + return 2 * paddle.exp(dy) + + main_program = paddle.static.Program() + start_program = paddle.static.Program() + + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + with paddle.static.program_guard(main_program, start_program): + data = paddle.static.data(name="X", shape=[None, 5], dtype="float32") + data.stop_gradient = False + ret = paddle.static.nn.static_pylayer(forward_fn, [data], backward_fn) + data_grad = paddle.static.gradients([ret], data)[0] + + exe.run(start_program) + x = np.array([[1.0, 2.0, 3.0, 4.0, 5.0]], dtype=np.float32) + x, x_grad, y = exe.run( + main_program, + feed={"X": x}, + fetch_list=[ + data.name, + data_grad.name, + ret.name + ], + ) + # x is Numpy + # x.data = [[1.0, 2.0, 3.0, 4.0, 5.0]] + # x.shape = [1, 5] + # y is Numpy + # y.data = [[2.7182817, 7.389056, 20.085537, 54.59815, 148.41316]] + # y.shape = [1, 5] + # x_grad is Numpy + # x_grad.data = [[5.4365635, 5.4365635, 5.4365635, 5.4365635, 5.4365635]] + # x_grad.shape = [1, 5] + """ + assert ( + in_dygraph_mode() is False + ), "please use PyLayer instead of static_pylayer in dygraph mode" + + assert isinstance(inputs, list) + if backward_fn is None: + for input_var in inputs: + if input_var.stop_gradient is False: + raise ValueError( + "``stop_gradient`` attr of all inputs to ``forward_fn`` are expected to be True, when ``backward_fn == None``, but {}.stop_gradient got {}".format( + input_var.name, input_var.stop_gradient + ) + ) + + check_type(name, "name", (str, type(None)), "fluid.layers.static_pylayer") + helper = LayerHelper('static_pylayer', **locals()) + copy_to_parent_func = lambda var: copy_var_to_parent_block(var, helper) + + assert forward_fn is not None and callable(forward_fn) + pylayer_block_manager = StaticPyLayerBlock(inputs) + with pylayer_block_manager.block(is_backward_block=False) as mgr: + origin_output = forward_fn(*inputs) + if origin_output is not None: + output = map_structure(copy_to_parent_func, origin_output) + mgr.fwd_outputs = flatten(output) + else: + mgr.fwd_outputs = [] + + current_block = helper.main_program.current_block() + current_block._sync_with_cpp() + if backward_fn is not None: + assert callable(backward_fn) + if origin_output is None: + output = [] + + # **Create the backward input** from the output of the op to build the + # backward block, and then delete it.
+ grad_var_ins = [] + for fwd_var in flatten(output): + fwd_var_name = fwd_var.name + bwd_var_name = _append_grad_suffix_(fwd_var_name) + var = current_block.create_var(name=bwd_var_name) + if not current_block.desc.has_var_recursive(fwd_var_name.encode()): + raise ValueError( + "Grad var {}: we can't find its related forward var {}".format( + bwd_var_name, fwd_var_name + ) + ) + + var.desc.set_dtype(fwd_var.dtype) + var.desc.set_shape(fwd_var.shape) + + grad_var_ins.append(var) + + assert isinstance(grad_var_ins, list) + with pylayer_block_manager.block(is_backward_block=True) as mgr: + grad_origin_output = backward_fn(*grad_var_ins) + if grad_origin_output is not None: + flat_grad_origin = flatten(grad_origin_output) + # NOTE(MarioLulab): ``current_block`` was defined outside + forward_input_names = current_block.ops[ + pylayer_block_manager.fwd_op_index + ].desc.input_arg_names() + assert len(forward_input_names) == len( + flat_grad_origin + ), f"needs to keep the number of inputs to ``forward_fn`` the same as the number of outputs to ``backward_fn``, \ + but got {len(forward_input_names)} and {len(flat_grad_origin)}" + + for bwd_output_name, fwd_input_name in zip( + flat_grad_origin, forward_input_names + ): + # NOTE(MarioLulab): Because `flat_grad_origin` are the Variables inside the backward block, which correspond one by one + # to the gradients of the inputs to the forward function, we need to establish a link between `flat_grad_origin` + # and the Variables outside the backward block which represent the gradients of the inputs to the forward function. + # The approach we have taken is renaming `flat_grad_origin` by the forward input name with the suffix "@GRAD", and aligning + # the order of `Out@GRAD` in `pylayer_grad` op with `flat_grad_origin`. At runtime, in `RunImpl` of the `pylayer_grad` op, + # we will find the inside_grad named after the forward input name with the suffix "@GRAD" in the scope, and assign `inside_grads` + # to `outside_grads`. + # + # Example: + # after running the code below to create the forward and backward blocks: + # + # out = forward_fn(x, y) # create forward block + # x_grad, y_grad = backward_fn(out_grad) # create backward block + # + # x.name is "X", y.name is "Y", and out.name is "tmp_0", but x_grad.name is "_generate_0", y_grad.name is "_generate_1". + # we rename x_grad to "X@GRAD" and y_grad to "Y@GRAD" inside the backward block. + # One thing to keep in mind is that we assume there was no Variable named "X@GRAD" inside the backward block before performing the rename operation. + # TODO(MarioLulab): We will validate whether the assumption above is too strong a hypothesis. + + # attach old var name into new + bwd_out_new = _append_grad_suffix_( + fwd_input_name + ) # "X" => "X@GRAD" + mgr.var_old_to_new[ + bwd_output_name.name + ] = bwd_out_new # e.g.
"tmp_0.mean_0": "X@GRAD" + + # **Delete the backward input** + for bwd_var in grad_var_ins: + current_block._remove_var(bwd_var.name) + + if origin_output is None: + return None + + return output diff --git a/test/ir/inference/program_config.py b/test/ir/inference/program_config.py index 9df3359c3cc3043faf97d9d0b2ccc057dc0767e6..3c4d82126b59a3632a890987b55669d29cd29bb1 100644 --- a/test/ir/inference/program_config.py +++ b/test/ir/inference/program_config.py @@ -115,6 +115,7 @@ _OP_WITHOUT_KERNEL_SET = { 'go', 'rnn_memory_helper_grad', 'conditional_block', + 'static_pylayer', 'while', 'send', 'recv', diff --git a/test/legacy_test/test_static_pylayer.py b/test/legacy_test/test_static_pylayer.py new file mode 100644 index 0000000000000000000000000000000000000000..2cd61de9517dfc02bb9c5a6a626ca212139c07d8 --- /dev/null +++ b/test/legacy_test/test_static_pylayer.py @@ -0,0 +1,327 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +from paddle import fluid +from paddle.fluid import core +from paddle.fluid.backward import append_backward +from paddle.fluid.framework import Program, program_guard + +np.random.seed(123) + + +class TestStaticPyLayerInputOutput(unittest.TestCase): + def test_return_single_var(self): + """ + pseudocode: + + y = 3 * x + """ + + paddle.enable_static() + + def forward_fn(x): + return 3 * x + + main_program = Program() + start_program = Program() + with program_guard(main_program, start_program): + data = paddle.static.data(name="X", shape=[1], dtype="float32") + out = paddle.static.nn.static_pylayer(forward_fn, [data]) + + place = ( + fluid.CUDAPlace(0) + if core.is_compiled_with_cuda() + else fluid.CPUPlace() + ) + exe = fluid.Executor(place) + x = np.array([2.0], dtype=np.float32) + (ret,) = exe.run(main_program, feed={"X": x}, fetch_list=[out.name]) + np.testing.assert_allclose( + np.asarray(ret), np.array([6.0], np.float32), rtol=1e-05 + ) + + # NOTE: Users should not return None from ``forward_fn`` when actually using it.
+ def test_return_0d_tensor(self): + """ + pseudocode: + + y = 3 * x + """ + + paddle.enable_static() + + def forward_fn(x): + return 3 * x + + main_program = Program() + start_program = Program() + with program_guard(main_program, start_program): + data = paddle.full(shape=[], dtype='float32', fill_value=2.0) + out = paddle.static.nn.static_pylayer(forward_fn, [data]) + + place = ( + fluid.CUDAPlace(0) + if core.is_compiled_with_cuda() + else fluid.CPUPlace() + ) + exe = fluid.Executor(place) + (ret,) = exe.run(main_program, fetch_list=[out.name]) + np.testing.assert_allclose( + np.asarray(ret), np.array(6.0, np.float32), rtol=1e-05 + ) + self.assertEqual(ret.shape, ()) + + def test_0d_tensor_backward(self): + ''' + pseudocode: + + y = 3 * x + dx = -5 * dy + ''' + + paddle.enable_static() + + def forward_fn(x): + return 3 * x + + def backward_fn(dy): + return -5 * dy + + main_program = Program() + start_program = Program() + with program_guard(main_program, start_program): + data = paddle.full(shape=[], dtype='float32', fill_value=-2.0) + data.stop_gradient = False + out = paddle.static.nn.static_pylayer( + forward_fn, [data], backward_fn + ) + append_backward(out) + + place = ( + fluid.CUDAPlace(0) + if core.is_compiled_with_cuda() + else fluid.CPUPlace() + ) + exe = fluid.Executor(place) + ret, x_grad = exe.run( + main_program, fetch_list=[out.name, data.grad_name] + ) + np.testing.assert_allclose(np.asarray(ret), np.array(-6.0), rtol=1e-05) + self.assertEqual(ret.shape, ()) + + np.testing.assert_allclose( + np.asarray(x_grad), np.array(-5.0), rtol=1e-05 + ) + self.assertEqual(x_grad.shape, ()) + + def test_return_var_typle(self): + paddle.enable_static() + + def forward_fn(a, b): + return 3 * a, -2 * b + + main_program = Program() + start_program = Program() + with program_guard(main_program, start_program): + data_1 = paddle.full(shape=[2, 4], dtype='float32', fill_value=-2.0) + data_2 = paddle.full(shape=[4, 5], dtype='float32', fill_value=10.0) + out_1, out_2 = paddle.static.nn.static_pylayer( + forward_fn, [data_1, data_2] + ) + + place = ( + fluid.CUDAPlace(0) + if core.is_compiled_with_cuda() + else fluid.CPUPlace() + ) + exe = fluid.Executor(place) + ret_1, ret_2 = exe.run( + main_program, fetch_list=[out_1.name, out_2.name] + ) + np.testing.assert_allclose( + np.asarray(ret_1), + np.full((2, 4), -6.0, dtype=np.float32), + rtol=1e-05, + ) + + np.testing.assert_allclose( + np.asarray(ret_2), + np.full((4, 5), -20.0, dtype=np.float32), + rtol=1e-05, + ) + + def test_return_forward_none(self): + paddle.enable_static() + + input_shape = (1, 3) + + def forward_fn(x): + y = 3 * x + return None + + main_program = Program() + start_program = Program() + with program_guard(main_program, start_program): + data = paddle.full( + shape=input_shape, dtype='float32', fill_value=-2.0 + ) + out = paddle.static.nn.static_pylayer(forward_fn, [data]) + + place = ( + fluid.CUDAPlace(0) + if core.is_compiled_with_cuda() + else fluid.CPUPlace() + ) + exe = fluid.Executor(place) + exe.run(main_program) + self.assertIsNone(out) + + def test_wrong_structure_exception(self): + """ + test not all ``stop_gradient`` of inputs is True when ``backward_fn`` is None, and + wrong number of inputs and outputs returned by ``forward_fn`` and ``backward_fn`` + """ + + paddle.enable_static() + + def forward_fn(a, b): + return 3 * a, -b, paddle.mean(b) + + def backward_fn(daout, dbout): + return 3 * daout, -dbout + + main_program = Program() + startup_program = Program() + with program_guard(main_program, 
startup_program): + data_1 = paddle.static.data( + name="data_1", shape=[2, 4], dtype="float32" + ) + data_2 = paddle.static.data( + name="data_2", shape=[6], dtype="float32" + ) + data_2.stop_gradient = False + with self.assertRaises(ValueError) as e: + out = paddle.static.nn.static_pylayer( + forward_fn, [data_1, data_2], backward_fn=None + ) + self.assertTrue( + "``stop_gradient`` attr of all inputs to ``forward_fn`` are expected to be True, when ``backward_fn == None``" + in str(e.exception) + ) + + with self.assertRaises(TypeError) as e: + out = paddle.static.nn.static_pylayer( + forward_fn, [data_1, data_2], backward_fn=backward_fn + ) + + +# TODO(MarioLulab): Disable now. We will refine and add testcases later. +class _TestControlFlowNestedStaticPyLayer(unittest.TestCase): + # TODO(MarioLulab): failed when i >= 5, fix it later + def _test_cond_inside_static_pylayer(self): + """ + forward pass: + _ _ _ _ _ _ _ _ + ---> a ---> | | -----> out_i + | | StaticPyLayer | + i ---------> |_ _ _ _ _ _ _ _| -----> out ---> loss + + + pseudocode: + def forward_fn(i, a): + if i < 5: + return i, a + a + else: + return i, a - a + + def backward_fn(diout, daout): + if diout < 5: + return 2 * diout, daout * daout + else: + return 2 * diout, cos(daout) + """ + + paddle.enable_static() + + def forward_fn(i, a): + return i, paddle.static.nn.cond( + i < 5.0, lambda: paddle.add(a, a), lambda: paddle.subtract(a, a) + ) + + def backward_fn(diout, daout): + return 2 * diout, paddle.static.nn.cond( + diout < 5.0, + lambda: paddle.multiply(daout, daout), + lambda: paddle.cos(daout), + ) + + main_program = Program() + start_program = Program() + with program_guard(main_program, start_program): + i = paddle.static.data(name="i", shape=[1], dtype="float32") + i.stop_gradient = False + a = 2.0 * i + out_i, out = paddle.static.nn.static_pylayer( + forward_fn, [i, a], backward_fn + ) + loss = paddle.exp(out) + append_backward(loss) + + place = ( + fluid.CUDAPlace(0) + if core.is_compiled_with_cuda() + else fluid.CPUPlace() + ) + exe = fluid.Executor(place) + for feed_i in range(0, 10): + print(feed_i) + expected_a = 2.0 * feed_i + if feed_i < 5: + expected_out_i = feed_i + expected_out = expected_a + expected_a + expected_out_grad = np.exp(expected_out) + expected_a_grad = expected_out_grad * expected_out_grad + expected_i_grad = 2 * expected_a_grad + 0 + else: + expected_out_i = feed_i + expected_out = expected_a - expected_a + expected_out_grad = np.exp(expected_out) + expected_a_grad = np.cos(expected_out_grad) + expected_i_grad = 2 * expected_a_grad + 0 + ret = exe.run( + main_program, + feed={'i': np.full((1), feed_i, dtype=np.float32)}, + fetch_list=[out.name, out.grad_name, a.grad_name, i.grad_name], + ) + np.testing.assert_allclose( + np.asarray(ret[0]), expected_out, rtol=1e-05 + ) + np.testing.assert_allclose( + np.asarray(ret[1]), expected_out_grad, rtol=1e-05 + ) + np.testing.assert_allclose( + np.asarray(ret[2]), expected_a_grad, rtol=1e-05 + ) + np.testing.assert_allclose( + np.asarray(ret[3]), expected_i_grad, rtol=1e-05 + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/legacy_test/test_static_pylayer_block.py b/test/legacy_test/test_static_pylayer_block.py new file mode 100644 index 0000000000000000000000000000000000000000..060ea4d22d05c02c7c9bf050078fed90ba81aba5 --- /dev/null +++ b/test/legacy_test/test_static_pylayer_block.py @@ -0,0 +1,66 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +from paddle import fluid +from paddle.fluid import core +from paddle.static import Executor, append_backward +from paddle.static.nn.static_pylayer import StaticPyLayerBlock + + +class StaticPyLayerBlockTest(unittest.TestCase): + def test_forward_and_backward(self): + paddle.enable_static() + main_program = fluid.Program() + startup_program = fluid.Program() + with fluid.program_guard(main_program, startup_program): + data = paddle.static.data(name='X', shape=[10, 1], dtype='float32') + data.stop_gradient = False + static_pylayer_manager = StaticPyLayerBlock(inputs=[data]) + fwd_out = paddle.tensor.create_tensor(dtype='float32') + with static_pylayer_manager.block(is_backward_block=False) as mgr: + hidden_fwd = paddle.static.nn.fc(x=data, size=10) + paddle.assign(hidden_fwd, fwd_out) + mgr.fwd_outputs = [fwd_out] + + grad_name = data.name + core.grad_var_suffix() + with static_pylayer_manager.block(is_backward_block=True) as mgr: + constant_tensor = paddle.tensor.fill_constant( + shape=[10, 1], dtype="float32", value=2.0 + ) + mgr.var_old_to_new[constant_tensor.name] = grad_name + + cpu = core.CPUPlace() + exe = Executor(cpu) + exe.run(startup_program) + + x = np.random.random(size=(10, 1)).astype('float32') + outs = exe.run(main_program, feed={'X': x}, fetch_list=[fwd_out])[0] + print(outs) + loss = paddle.mean(fwd_out) + append_backward(loss=loss) + outs = exe.run( + main_program, + feed={'X': x}, + fetch_list=[data.grad_name], + )[0] + print(outs) + + +if __name__ == '__main__': + unittest.main()
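A minimal usage sketch of the new `paddle.static.nn.static_pylayer` API introduced by this patch, condensed from the docstring example and the legacy tests above. The forward/backward functions, variable names, shapes, and values are illustrative only, not part of the patch itself.

```python
# Sketch only: mirrors the docstring example of static_pylayer added in this patch.
import numpy as np
import paddle

paddle.enable_static()

def forward_fn(x):
    return paddle.exp(x)

def backward_fn(dy):
    # custom gradient: dx = 2 * exp(dy)
    return 2 * paddle.exp(dy)

main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    x = paddle.static.data(name="X", shape=[None, 5], dtype="float32")
    x.stop_gradient = False
    # builds the forward and backward sub-blocks and a single `pylayer` op
    y = paddle.static.nn.static_pylayer(forward_fn, [x], backward_fn)
    x_grad = paddle.static.gradients([y], x)[0]

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup_program)
out, grad = exe.run(
    main_program,
    feed={"X": np.ones([1, 5], dtype=np.float32)},
    fetch_list=[y.name, x_grad.name],
)
# out == exp(X); grad is produced by backward_fn via the pylayer_grad op
```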