Unverified commit 4bd28b30, authored by Qiyang Min, committed by GitHub

Merge pull request #15831 from velconia/imperative_engine

Imperative training network to the end
@@ -163,6 +163,20 @@ std::vector<OpDesc *> BlockDesc::AllOps() const {
   return res;
 }
 
+void BlockDesc::Clear() {
+  // clear all ops
+  ops_.clear();
+
+  // clear all vars which are not persistable
+  for (auto it = vars_.begin(); it != vars_.end();) {
+    if (it->second->Persistable()) {
+      ++it;
+    } else {
+      vars_.erase(it++);
+    }
+  }
+}
+
 void BlockDesc::Flush() {
   for (auto &op_desc : ops_) {
     op_desc->Flush();
...
@@ -97,6 +97,8 @@ class BlockDesc {
   std::vector<OpDesc *> AllOps() const;
 
+  void Clear();
+
   size_t OpSize() const { return ops_.size(); }
 
   OpDesc *Op(int idx) const { return ops_.at(idx).get(); }
...
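Note: the new BlockDesc::Clear() empties the op list and erases every variable that is not persistable, so parameters survive while per-iteration temporaries go away. A minimal pure-Python sketch of that filtering rule (the variable names below are made up for illustration; only the persistable/non-persistable split comes from the C++ above):

    # hypothetical block contents: op types, and a name -> persistable map
    block_ops = ["mul", "elementwise_add", "softmax"]
    block_vars = {"fc_0.w_0": True, "fc_0.b_0": True, "fc_0.tmp_0": False}

    block_ops.clear()  # Clear() drops every recorded op
    block_vars = {name: p for name, p in block_vars.items() if p}  # keeps persistable vars only

    assert block_vars == {"fc_0.w_0": True, "fc_0.b_0": True}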
@@ -114,23 +114,23 @@ class VarBase {
  public:
   VarBase() : VarBase(new framework::Variable(), new VarBase(true)) {}
 
-  // Owns `var` and `grad`
+  explicit VarBase(bool stop_gradient)
+      : VarBase(new framework::Variable(),
+                stop_gradient ? nullptr : new VarBase(true), stop_gradient) {}
+
   VarBase(framework::Variable* var, VarBase* grad)
+      : VarBase(var, grad, false) {}
+
+ private:
+  VarBase(framework::Variable* var, VarBase* grad, bool stop_gradient)
       : var_desc_(nullptr),
         var_(var),
         grads_(grad),
-        stop_gradient_(false),
-        pre_op_(nullptr),
-        pre_op_out_idx_(-1) {}
-
-  explicit VarBase(bool stop_gradient)
-      : var_desc_(nullptr),
-        var_(new framework::Variable()),
-        grads_(stop_gradient ? nullptr : new VarBase(true)),
         stop_gradient_(stop_gradient),
         pre_op_(nullptr),
         pre_op_out_idx_(-1) {}
 
+ public:
   virtual ~VarBase() {
     if (var_) {
       delete var_;
@@ -141,11 +141,13 @@ class VarBase {
     }
   }
 
-  OpBase* PreOp() const { return pre_op_; }
-  int PreOpOutIdx() const { return pre_op_out_idx_; }
+  inline OpBase* PreOp() const { return pre_op_; }
+  inline int PreOpOutIdx() const { return pre_op_out_idx_; }
 
-  void SetStopGradient(bool stop_gradient) { stop_gradient_ = stop_gradient; }
-  bool IsStopGradient() const { return stop_gradient_; }
+  inline void SetStopGradient(bool stop_gradient) {
+    stop_gradient_ = stop_gradient;
+  }
+  inline bool IsStopGradient() const { return stop_gradient_; }
 
   void RunBackward();
...
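Note: the VarBase constructors now delegate to a single private constructor, so stop_gradient_ is assigned in exactly one place and a stop-gradient variable is created with no grad VarBase at all; the small accessors are also marked inline since the tracer now calls IsStopGradient() on every traced input. A hedged pure-Python mirror of the delegation pattern (illustration only, not Paddle code):

    class VarSketch(object):
        def __init__(self, var=None, grad=None, stop_gradient=False, _delegated=False):
            if not _delegated:
                # mirrors `explicit VarBase(bool stop_gradient)`
                var = var or object()
                grad = None if stop_gradient else VarSketch(
                    var=object(), grad=None, stop_gradient=True, _delegated=True)
            self.var_ = var
            self.grads_ = grad
            self.stop_gradient_ = stop_gradient

    assert VarSketch(stop_gradient=True).grads_ is None
    assert VarSketch(stop_gradient=False).grads_.stop_gradient_ is True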
@@ -14,6 +14,8 @@
 #include "paddle/fluid/imperative/tracer.h"
 
+#include <set>
+
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/enforce.h"
@@ -66,16 +68,18 @@ platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs) {
   return result;
 }
 
-void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
-                   const VarBasePtrMap& outputs, framework::BlockDesc* block,
-                   const platform::Place expected_place,
-                   const bool stop_gradient) {
+std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
+                                    const VarBasePtrMap& outputs,
+                                    framework::BlockDesc* block,
+                                    const platform::Place expected_place,
+                                    const bool stop_gradient) {
   std::map<std::string, VarBase*> vars;
 
   framework::OpDesc* op_desc = op->op_desc_;
   VLOG(3) << "tracer tracing " << op_desc->Type();
   op_desc->InferShape(*block);
   op_desc->InferVarType(block);
   std::unique_ptr<framework::OperatorBase> op_base =
       framework::OpRegistry::CreateOp(*op_desc);
@@ -92,7 +96,7 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
       invars.emplace_back(inp->var_);
       vars[inp->var_desc_->Name()] = inp;
-      if (inp->PreOp()) {
+      if (inp->PreOp() && !inp->IsStopGradient()) {
         op->pre_ops_[it.first].push_back(inp->PreOp());
         op->pre_ops_out_idx_[it.first].push_back(inp->PreOpOutIdx());
       } else {
@@ -142,6 +146,8 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
       framework::ExecutionContext(prepared_op.op, scope, *prepared_op.dev_ctx,
                                   prepared_op.ctx, prepared_op.kernel_configs));
 
+  std::set<std::string> vars_saved_for_backward;
+
   if (!stop_gradient) {
     std::unique_ptr<std::unordered_map<std::string, std::string>> grad_to_var(
         new std::unordered_map<std::string, std::string>());
@@ -161,6 +167,7 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
           PADDLE_ENFORCE(fwd_var_it != vars.end());
           // Forward inputs or outputs.
           grad_in_vars.push_back(fwd_var_it->second->var_);
+          vars_saved_for_backward.insert(it.first);
         } else {
           VarBase* var = vars[var_it->second];
           if (!var->grads_->var_->IsInitialized()) {
@@ -194,6 +201,7 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
   }
 
   op->block_ = block;
+  return vars_saved_for_backward;
 }
 
 std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
@@ -203,7 +211,7 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
   op->input_vars_[PyLayer::kFwdInp] = inputs;
   op->output_vars_[PyLayer::kFwdOut] = PyLayer::Apply(op->forward_id_, inputs);
   for (VarBase* inp : inputs) {
-    if (inp->PreOp()) {
+    if (inp->PreOp() && !inp->IsStopGradient()) {
       op->pre_ops_[PyLayer::kFwdInp].push_back(inp->PreOp());
       op->pre_ops_out_idx_[PyLayer::kFwdInp].push_back(inp->PreOpOutIdx());
     } else {
...
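Note: Trace() now records, for each forward input/output slot that the generated grad op will read, the slot name into vars_saved_for_backward and returns that set; gradient chaining also skips inputs whose IsStopGradient() is true. The Python caller (Block._trace_op below) uses the returned set to keep references only to the variables the backward pass still needs. A hedged sketch of that consumption pattern, with illustrative slot names:

    from collections import defaultdict

    def keep_backward_refs(op_inputs, op_outputs, backward_refs):
        # keep only the input/output slots the backward op will actually read
        refs = defaultdict(list)
        for k in op_inputs:
            if k in backward_refs:
                refs[k] = op_inputs[k]
        for k in op_outputs:
            if k in backward_refs:
                refs[k] = op_outputs[k]
        return refs

    # e.g. an op whose grad kernel reads X and Y but not Out
    refs = keep_backward_refs({"X": ["x"], "Y": ["y"]}, {"Out": ["out"]}, {"X", "Y"})
    assert sorted(refs) == ["X", "Y"]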
@@ -15,6 +15,7 @@
 #pragma once
 
 #include <map>
+#include <set>
 #include <string>
 #include <vector>
@@ -43,10 +44,11 @@ class Tracer {
   virtual ~Tracer() {}
 
-  void Trace(OpBase* op, const VarBasePtrMap& inputs,
-             const VarBasePtrMap& outputs, framework::BlockDesc* block,
-             const platform::Place expected_place,
-             const bool stop_gradient = false);
+  std::set<std::string> Trace(OpBase* op, const VarBasePtrMap& inputs,
+                              const VarBasePtrMap& outputs,
+                              framework::BlockDesc* block,
+                              const platform::Place expected_place,
+                              const bool stop_gradient = false);
 
   std::vector<VarBase*> PyTrace(OpBase* op, const std::vector<VarBase*>& inputs,
                                 bool stop_gradient = false);
...
@@ -34,8 +34,8 @@ void BindTracer(pybind11::module* m) {
              framework::BlockDesc* block,
              const platform::CPUPlace expected_place,
              const bool stop_gradient = false) {
-            self.Trace(op, inputs, outputs, block, expected_place,
-                       stop_gradient);
+            return self.Trace(op, inputs, outputs, block, expected_place,
+                              stop_gradient);
           })
       .def("trace",
            [](imperative::Tracer& self, imperative::OpBase* op,
@@ -44,8 +44,8 @@ void BindTracer(pybind11::module* m) {
              framework::BlockDesc* block,
              const platform::CUDAPlace expected_place,
              const bool stop_gradient = false) {
-            self.Trace(op, inputs, outputs, block, expected_place,
-                       stop_gradient);
+            return self.Trace(op, inputs, outputs, block, expected_place,
+                              stop_gradient);
           })
       .def("py_trace", &imperative::Tracer::PyTrace,
            pybind11::return_value_policy::take_ownership);
...
@@ -189,6 +189,8 @@ void BindBlockDesc(pybind11::module *m) {
             return self.HasVar(name);
           },
           pybind11::return_value_policy::reference)
+      .def("_clear_block", [](pd::BlockDesc &self) { return self.Clear(); },
+           pybind11::return_value_policy::reference)
       .def("_rename_var",
            [](pd::BlockDesc &self, const pybind11::bytes &byte_name,
               const pybind11::bytes &byte_name_new) {
...
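Note: this binding exposes BlockDesc::Clear() to Python as _clear_block on the C++ block desc. A hedged usage sketch of the call path, assuming the fluid API of this branch (the Block._clear_block wrapper added below additionally clears the Python-side op list):

    import paddle.fluid as fluid

    block = fluid.default_main_program().global_block()
    block.desc._clear_block()  # C++ BlockDesc::Clear(): drop ops and non-persistable vars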
@@ -387,16 +387,19 @@ class Variable(object):
                 # get_capacity is implemented
                 pass
 
-        self.block.vars[name] = self
-        self.op = None
-        self.stop_gradient = stop_gradient
-        self.is_data = is_data
         if _in_imperative_mode():
+            # record vars in tracer rather than blocks
             self._ivar = kwargs.get("ivar", None)
             if not self._ivar:
-                self._ivar = core.VarBase()
+                self._ivar = core.VarBase(stop_gradient)
             self._ivar.desc = self.desc
-            self._ivar.stop_gradient = stop_gradient
+            if persistable:
+                self.block.vars[name] = self
+        else:
+            self.block.vars[name] = self
+        self.op = None
+        self.stop_gradient = stop_gradient
+        self.is_data = is_data
 
     def _numpy(self):
         new_ivar = self._ivar._copy_to(core.CPUPlace(), True)
@@ -739,6 +742,7 @@ class Operator(object):
         if _in_imperative_mode():
             self.iop = core.OpBase()
             self.iop.desc = self.desc
+
             self.inputs = defaultdict(list)
             if inputs is not None:
                 for k, v in six.iteritems(inputs):
@@ -746,6 +750,7 @@ class Operator(object):
                         self.inputs[k].append(v._ivar)
                     elif isinstance(v, list) or isinstance(v, tuple):
                         self.inputs[k].extend([var._ivar for var in v])
+
             self.outputs = defaultdict(list)
             if outputs is not None:
                 for k, v in six.iteritems(outputs):
@@ -1195,6 +1200,15 @@ class Block(object):
             else:
                 raise ValueError("Var {0} is not found recursively".format(name))
 
+    def _clear_block(self):
+        # TODO(minqiyang): move this to backward_hooks
+        self.desc._clear_block()
+
+        for name in self.vars.keys():
+            assert self.vars[name].persistable
+
+        del self.ops[:]
+
     def all_parameters(self):
         return list(self.iter_parameters())
@@ -1325,18 +1339,31 @@ class Block(object):
             inputs=kwargs.get("inputs", None),
             outputs=kwargs.get("outputs", None),
             attrs=kwargs.get("attrs", None))
-        self.ops.append(op)
 
-        # TODO(minqiyang): add stop_gradient support in static mode too.
-        # currently, we only support stop_gradient in imperative mode.
-        self._trace_op(op, kwargs.get("stop_gradient", False))
+        if _in_imperative_mode():
+            # record ops in tracer rather than blocks
+            #
+            # TODO(minqiyang): add op stop_gradient support in static mode too.
+            # currently, we only support stop_gradient in imperative mode.
+            self._trace_op(op, kwargs.get("stop_gradient", False))
+        self.ops.append(op)
         return op
 
     def _trace_op(self, op, stop_gradient=False):
-        if _in_imperative_mode():
-            _imperative_tracer().trace(op.iop, op.inputs, op.outputs, self.desc,
-                                       _imperative_current_expected_place_,
-                                       stop_gradient)
+        backward_refs = _imperative_tracer().trace(
+            op.iop, op.inputs, op.outputs, self.desc,
+            _imperative_current_expected_place_, stop_gradient)
+
+        # TODO(minqiyang): support backward_hooks to eager remove backward_refs
+        op.backward_refs = defaultdict(list)
+        for k, v in six.iteritems(op.inputs):
+            if k in backward_refs:
+                op.backward_refs[k] = op.inputs[k]
+
+        for k, v in six.iteritems(op.outputs):
+            if k in backward_refs:
+                op.backward_refs[k] = op.outputs[k]
 
     def _insert_op(self, index, *args, **kwargs):
         """
@@ -1391,7 +1418,8 @@
             outputs=kwargs.get("outputs", None),
             attrs=kwargs.get("attrs", None))
         self.ops.insert(0, op)
-        self._trace_op(op, kwargs.get("stop_gradient", False))
+        if _in_imperative_mode():
+            self._trace_op(op, kwargs.get("stop_gradient", False))
         return op
 
     def _sync_with_cpp(self):
...
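Note: taken together, the Python changes make imperative mode record ops and temporary variables with the tracer rather than in the block, keep only the references listed in op.backward_refs, and let the caller drop everything else after each iteration. A hedged sketch of the per-batch pattern the updated tests rely on (assumes the fluid.imperative API of this branch, a model built from fluid.imperative.Layer, and img/label already wrapped by to_variable):

    import paddle.fluid as fluid

    def train_one_batch(model, optimizer, img, label):
        cost = model(img)
        loss = fluid.layers.cross_entropy(cost, label)
        avg_loss = fluid.layers.mean(loss)

        avg_loss._backward()
        optimizer.minimize(avg_loss)
        model.clear_gradients()

        # drop this batch's ops and non-persistable vars so the block stops growing
        fluid.default_main_program().global_block()._clear_block()
        return avg_loss._numpy()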
@@ -105,7 +105,7 @@ class MNIST(fluid.imperative.Layer):
 class TestImperativeMnist(unittest.TestCase):
     def test_mnist_float32(self):
         seed = 90
-        batch_num = 2
+        epoch_num = 1
         with fluid.imperative.guard():
             fluid.default_startup_program().random_seed = seed
             fluid.default_main_program().random_seed = seed
@@ -113,39 +113,40 @@ class TestImperativeMnist(unittest.TestCase):
             mnist = MNIST("mnist")
             sgd = SGDOptimizer(learning_rate=1e-3)
             train_reader = paddle.batch(
-                paddle.dataset.mnist.train(), batch_size=128)
+                paddle.dataset.mnist.train(), batch_size=128, drop_last=True)
 
             dy_param_init_value = {}
-            for batch_id, data in enumerate(train_reader()):
-                if batch_id >= batch_num:
-                    break
-
-                dy_x_data = np.array(
-                    [x[0].reshape(1, 28, 28) for x in data]).astype('float32')
-                y_data = np.array([x[1] for x in data]).astype('int64').reshape(
-                    128, 1)
-
-                img = to_variable(dy_x_data)
-                label = to_variable(y_data)
-                label._stop_gradient = True
-
-                cost = mnist(img)
-                loss = fluid.layers.cross_entropy(cost, label)
-                avg_loss = fluid.layers.mean(loss)
-
-                dy_out = avg_loss._numpy()
-
-                if batch_id == 0:
-                    for param in fluid.default_main_program().global_block(
-                    ).all_parameters():
-                        dy_param_init_value[param.name] = param._numpy()
-
-                avg_loss._backward()
-                sgd.minimize(avg_loss)
-                mnist.clear_gradients()
-
-                dy_param_value = {}
-                for param in fluid.default_main_program().global_block(
-                ).all_parameters():
-                    dy_param_value[param.name] = param._numpy()
+            for epoch in range(epoch_num):
+                for batch_id, data in enumerate(train_reader()):
+                    dy_x_data = np.array(
+                        [x[0].reshape(1, 28, 28)
+                         for x in data]).astype('float32')
+                    y_data = np.array(
+                        [x[1] for x in data]).astype('int64').reshape(128, 1)
+
+                    img = to_variable(dy_x_data)
+                    label = to_variable(y_data)
+                    label._stop_gradient = True
+
+                    cost = mnist(img)
+                    loss = fluid.layers.cross_entropy(cost, label)
+                    avg_loss = fluid.layers.mean(loss)
+
+                    dy_out = avg_loss._numpy()
+
+                    if epoch == 0 and batch_id == 0:
+                        for param in mnist.parameters():
+                            dy_param_init_value[param.name] = param._numpy()
+
+                    avg_loss._backward()
+                    sgd.minimize(avg_loss)
+                    mnist.clear_gradients()
+
+                    fluid.default_main_program().global_block()._clear_block()
+
+                    dy_param_value = {}
+                    for param in mnist.parameters():
+                        dy_param_value[param.name] = param._numpy()
 
         with new_program_scope():
             fluid.default_startup_program().random_seed = seed
@@ -157,7 +158,7 @@ class TestImperativeMnist(unittest.TestCase):
             mnist = MNIST("mnist")
             sgd = SGDOptimizer(learning_rate=1e-3)
             train_reader = paddle.batch(
-                paddle.dataset.mnist.train(), batch_size=128)
+                paddle.dataset.mnist.train(), batch_size=128, drop_last=True)
 
             img = fluid.layers.data(
                 name='pixel', shape=[1, 28, 28], dtype='float32')
@@ -170,8 +171,7 @@ class TestImperativeMnist(unittest.TestCase):
             # initialize params and fetch them
             static_param_init_value = {}
             static_param_name_list = []
-            for param in fluid.default_startup_program().global_block(
-            ).all_parameters():
+            for param in mnist.parameters():
                 static_param_name_list.append(param.name)
 
             out = exe.run(fluid.default_startup_program(),
@@ -180,26 +180,29 @@ class TestImperativeMnist(unittest.TestCase):
             for i in range(len(static_param_name_list)):
                 static_param_init_value[static_param_name_list[i]] = out[i]
 
-            for batch_id, data in enumerate(train_reader()):
-                if batch_id >= batch_num:
-                    break
-
-                static_x_data = np.array(
-                    [x[0].reshape(1, 28, 28) for x in data]).astype('float32')
-                y_data = np.array([x[1] for x in data]).astype('int64').reshape(
-                    [128, 1])
-
-                fetch_list = [avg_loss.name]
-                fetch_list.extend(static_param_name_list)
-                out = exe.run(fluid.default_main_program(),
-                              feed={"pixel": static_x_data,
-                                    "label": y_data},
-                              fetch_list=fetch_list)
-
-                static_param_value = {}
-                static_out = out[0]
-                for i in range(1, len(out)):
-                    static_param_value[static_param_name_list[i - 1]] = out[i]
+            for epoch in range(epoch_num):
+                for batch_id, data in enumerate(train_reader()):
+                    static_x_data = np.array(
+                        [x[0].reshape(1, 28, 28)
+                         for x in data]).astype('float32')
+                    y_data = np.array(
+                        [x[1] for x in data]).astype('int64').reshape([128, 1])
+
+                    fetch_list = [avg_loss.name]
+                    fetch_list.extend(static_param_name_list)
+                    out = exe.run(
+                        fluid.default_main_program(),
+                        feed={"pixel": static_x_data,
+                              "label": y_data},
+                        fetch_list=fetch_list)
+
+                    static_param_value = {}
+                    static_out = out[0]
+                    for i in range(1, len(out)):
+                        static_param_value[static_param_name_list[i - 1]] = out[
+                            i]
+
+        self.assertTrue(np.allclose(dy_x_data.all(), static_x_data.all()))
 
         for key, value in six.iteritems(static_param_init_value):
             self.assertTrue(np.allclose(value, dy_param_init_value[key]))
@@ -207,7 +210,7 @@ class TestImperativeMnist(unittest.TestCase):
         self.assertTrue(np.allclose(static_out, dy_out))
 
         for key, value in six.iteritems(static_param_value):
-            self.assertTrue(np.allclose(value, dy_param_value[key]))
+            self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-5))
 
 if __name__ == '__main__':
...
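Note: the MNIST test now trains a full epoch instead of stopping after two batches, fetches parameters through mnist.parameters() instead of walking the global block, clears the block after every batch, and passes drop_last=True because both branches reshape the labels to a fixed (128, 1); a trailing partial batch would break that reshape, as this small illustration (not test code) shows:

    import numpy as np

    np.arange(128).reshape(128, 1)     # full batch: fine
    try:
        np.arange(96).reshape(128, 1)  # a trailing partial batch would fail
    except ValueError as e:
        print("partial batch breaks the reshape:", e)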
@@ -231,7 +231,7 @@ class TestImperativeResnet(unittest.TestCase):
         seed = 90
 
         batch_size = train_parameters["batch_size"]
-        batch_num = 1
+        batch_num = 2
         with fluid.imperative.guard():
             fluid.default_startup_program().random_seed = seed
             fluid.default_main_program().random_seed = seed
@@ -286,6 +286,8 @@ class TestImperativeResnet(unittest.TestCase):
                 optimizer.minimize(avg_loss)
                 resnet.clear_gradients()
 
+                fluid.default_main_program().global_block()._clear_block()
+
                 dy_param_value = {}
                 for param in resnet.parameters():
                     dy_param_value[param.name] = param._numpy()
@@ -319,11 +321,9 @@ class TestImperativeResnet(unittest.TestCase):
             static_param_init_value = {}
             static_param_name_list = []
             static_grad_name_list = []
-            for param in fluid.default_startup_program().global_block(
-            ).all_parameters():
+            for param in resnet.parameters():
                 static_param_name_list.append(param.name)
-            for param in fluid.default_main_program().global_block(
-            ).all_parameters():
+            for param in resnet.parameters():
                 if not param.stop_gradient:
                     static_grad_name_list.append(param.name +
                                                  core.grad_var_suffix())
...