“142e832d21715c0ce651e4ac04f10554945e5ad7”上不存在“paddle/cinn/optim/replace_const_param_to_integer.h”
未验证 提交 a0631364 编写于 作者: X xiongkun 提交者: GitHub

Fix test calc gradient (#37672)

* add scope_guard

* 1. fix control flow cases 2. fix calc_gradient
上级 74fdba7c
...@@ -16,6 +16,9 @@ ...@@ -16,6 +16,9 @@
#include "paddle/fluid/framework/executor_gc_helper.h" #include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/new_executor/data_transfer.h" #include "paddle/fluid/framework/new_executor/data_transfer.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -127,6 +130,9 @@ void build_variable_scope(const framework::BlockDesc& block, ...@@ -127,6 +130,9 @@ void build_variable_scope(const framework::BlockDesc& block,
for (auto& var_desc : block.AllVars()) { for (auto& var_desc : block.AllVars()) {
auto var_name = var_desc->Name(); auto var_name = var_desc->Name();
// TODO(xiongkun): user may create a variable with name that exists before.
// under such circumstances, we should raise a error. Currently we can't
// get the var_desc of startup_program, so leave it later.
if (var_name == framework::kEmptyVarName) { if (var_name == framework::kEmptyVarName) {
continue; continue;
} }
...@@ -149,7 +155,7 @@ void build_variable_scope(const framework::BlockDesc& block, ...@@ -149,7 +155,7 @@ void build_variable_scope(const framework::BlockDesc& block,
} }
void create_all_ops(const framework::BlockDesc& block, void create_all_ops(const framework::BlockDesc& block,
std::vector<std::shared_ptr<OperatorBase>>* ops) { std::vector<std::unique_ptr<OperatorBase>>* ops) {
for (auto& op : block.AllOps()) { for (auto& op : block.AllOps()) {
VLOG(3) << "CreateOp from : " << op->Type(); VLOG(3) << "CreateOp from : " << op->Type();
...@@ -164,7 +170,7 @@ void create_all_ops(const framework::BlockDesc& block, ...@@ -164,7 +170,7 @@ void create_all_ops(const framework::BlockDesc& block,
} }
auto op_base = auto op_base =
info.Creator()(op->Type(), inputs_names, outputs_names, op_attr_map); info.Creator()(op->Type(), inputs_names, outputs_names, op_attr_map);
ops->emplace_back(std::shared_ptr<OperatorBase>(op_base)); ops->emplace_back(std::unique_ptr<OperatorBase>(op_base));
} }
} }
...@@ -260,10 +266,24 @@ void build_op_func_list(const platform::Place& place, ...@@ -260,10 +266,24 @@ void build_op_func_list(const platform::Place& place,
Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope() Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
: var_scope->GetMutableScope(); : var_scope->GetMutableScope();
auto& all_op_kernels = OperatorWithKernel::AllOpKernels(); auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
std::vector<std::unique_ptr<OperatorBase>>
ops_unique; // its elements will be moved to vec_func_list
// Step 1: create all ops for current block.
create_all_ops(block, &ops_unique);
// If gc is enabled and block size > 1
const ProgramDesc& main_program = *block.Program();
operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
main_program, block.ID(), ops_unique);
operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
main_program, block.ID(), ops_unique);
operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
main_program, block.ID(), ops_unique);
std::vector<std::shared_ptr<OperatorBase>> std::vector<std::shared_ptr<OperatorBase>>
ops; // its elements will be moved to vec_func_list ops; // its elements will be moved to vec_func_list
// Step 1: create all ops for current block. for (auto& op_unique : ops_unique) {
create_all_ops(block, &ops); ops.emplace_back(std::move(op_unique));
}
auto unused_var_map = get_unused_vars(block, ops); auto unused_var_map = get_unused_vars(block, ops);
for (size_t i = 0; i < ops.size(); ++i) { for (size_t i = 0; i < ops.size(); ++i) {
......
...@@ -33,6 +33,7 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, ...@@ -33,6 +33,7 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
if (scope) { if (scope) {
auto name_list = scope->LocalVarNames(); auto name_list = scope->LocalVarNames();
for (auto name : name_list) { for (auto name : name_list) {
VLOG(4) << "Sync Variable from variable scope: " << name;
auto v = scope->Var(name); auto v = scope->Var(name);
if (!global_scope_.HasVar(name)) { if (!global_scope_.HasVar(name)) {
global_scope_.AddVar(name, *v); global_scope_.AddVar(name, *v);
...@@ -87,8 +88,9 @@ void StandaloneExecutor::BuildVariableScope(const framework::ProgramDesc& pdesc, ...@@ -87,8 +88,9 @@ void StandaloneExecutor::BuildVariableScope(const framework::ProgramDesc& pdesc,
if (var->Name() == framework::kEmptyVarName) { if (var->Name() == framework::kEmptyVarName) {
continue; continue;
} }
if (!var_scope->HasVar(var->Name())) { if (!var_scope->HasVar(var->Name())) {
VLOG(4) << "Create variable from startup_prog: "
<< var->Proto()->SerializeAsString();
var_scope->AddVar(var->Name(), var); var_scope->AddVar(var->Name(), var);
} }
} }
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
from __future__ import print_function from __future__ import print_function
import paddle
import unittest import unittest
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -83,6 +84,7 @@ class TestDoubleGrad(unittest.TestCase): ...@@ -83,6 +84,7 @@ class TestDoubleGrad(unittest.TestCase):
class TestGradientWithPrune(unittest.TestCase): class TestGradientWithPrune(unittest.TestCase):
def test_prune(self): def test_prune(self):
with paddle.fluid.scope_guard(paddle.static.Scope()):
x = fluid.data(name='x', shape=[3], dtype='float32') x = fluid.data(name='x', shape=[3], dtype='float32')
x.stop_gradient = False x.stop_gradient = False
x1, x2, x3 = fluid.layers.split(x, dim=0, num_or_sections=3) x1, x2, x3 = fluid.layers.split(x, dim=0, num_or_sections=3)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册