Unverified commit 102fc859, authored by Qiyang Min, committed by GitHub

Merge pull request #16777 from velconia/dygraph_untrack_op

Imperative tracer does not hold op any more
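This merge changes the imperative (dygraph) tracer so it no longer keeps whole Operator objects alive after tracing: in train mode it records only the input/output references needed for backward, and the new eval mode skips that bookkeeping entirely, so forward-only runs hold no ops. The user-facing surface is the train()/eval() switch added to Layer. A minimal usage sketch (the MNIST layer is the one defined in the test added at the end of this diff):

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable

with fluid.dygraph.guard():
    model = MNIST("mnist")        # any fluid.dygraph.Layer works here
    model.eval()                  # tracer stops retaining ops/backward refs
    img = to_variable(np.random.rand(1, 1, 28, 28).astype('float32'))
    out = model(img)              # forward-only; intermediates can be freed
    model.train()                 # restore op recording for backward()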
@@ -48,6 +48,12 @@ class Layer(core.Layer):
         self._helper = LayerObjectHelper(self._full_name)

+    def train(self):
+        framework._dygraph_tracer().train_mode()
+
+    def eval(self):
+        framework._dygraph_tracer().eval_mode()
+
     def full_name(self):
         """Full name for this layer.
@@ -254,6 +260,12 @@ class PyLayer(core.PyLayer):
     def __init__(self):
         super(PyLayer, self).__init__()

+    def train(self):
+        framework._dygraph_tracer().train_mode()
+
+    def eval(self):
+        framework._dygraph_tracer().eval_mode()
+
     @classmethod
     def _do_forward(cls, inputs):
         return cls._to_tuple(cls.forward(inputs))
...
@@ -24,7 +24,9 @@ __all__ = ['Tracer']


 def release_op(op):
-    del framework._dygraph_tracer()._ops[op._trace_id]
+    del framework._dygraph_tracer()._ops[op._trace_id].inputs
+    del framework._dygraph_tracer()._ops[op._trace_id].outputs
+    del framework._dygraph_tracer()._ops[op._trace_id].backward_refs


 class Tracer(core.Tracer):
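Because the tracer keeps the _ops[trace_id] entry for bookkeeping, release_op now frees memory by deleting only the attributes that pin forward tensors, instead of dropping the whole entry. A toy illustration in plain Python (FakeOp is a stand-in, not a fluid type):

class FakeOp(object):
    def __init__(self, inputs, outputs):
        self.inputs = inputs          # references that keep forward tensors alive
        self.outputs = outputs
        self.backward_refs = {}

ops = {0: FakeOp(inputs=["x"], outputs=["y"])}

# Old behaviour: del ops[0] removed the entire entry.
# New behaviour: the entry survives, but its tensor references are released,
# so the variables it pointed at become collectable.
del ops[0].inputs
del ops[0].outputs
del ops[0].backward_refs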
@@ -38,6 +40,7 @@ class Tracer(core.Tracer):
         self._ops = defaultdict()
         self._vars = defaultdict()
         self._trace_id = 0
+        self._train_mode = True

     def trace_var(self, name, var):
         self._vars[name] = var
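Note why the boolean flag (self._train_mode) and the toggle methods (train_mode() / eval_mode(), defined at the end of this file) must not share a name: an instance attribute assigned in __init__ shadows a same-named method, as this small demonstration shows:

class Shadowed(object):
    def __init__(self):
        self._train_mode = True      # instance attribute...

    def _train_mode(self):           # ...hides this method on every instance
        self._train_mode = True

Shadowed()._train_mode()             # TypeError: 'bool' object is not callable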
@@ -46,15 +49,57 @@ class Tracer(core.Tracer):
         return list((item for name, item in six.iteritems(self._vars)
                      if isinstance(item, framework.Parameter)))

-    def trace_op(self, op, stop_gradient=False):
+    def trace_op(self, op, inputs, outputs, stop_gradient=False):
+        # TODO(minqiyang): remove this line after we take apart all
+        # backward grads and forward variables
+        if self._train_mode:
+            op.inputs = inputs
+            inps = defaultdict(list)
+            for k, vars in six.iteritems(inputs):
+                if isinstance(vars, framework.Variable):
+                    inps[k].append(vars._ivar)
+                elif isinstance(vars, list) or isinstance(vars, tuple):
+                    for var in vars:
+                        inps[k].append(var._ivar)
+
+            op.outputs = outputs
+            outs = defaultdict(list)
+            for k, vars in six.iteritems(outputs):
+                if isinstance(vars, framework.Variable):
+                    outs[k].append(vars._ivar)
+                elif isinstance(vars, list) or isinstance(vars, tuple):
+                    for var in vars:
+                        outs[k].append(var._ivar)
+        else:
+            inps = defaultdict(list)
+            for k, vars in six.iteritems(inputs):
+                if isinstance(vars, framework.Variable):
+                    op.previous_ops.append(vars.op)
+                    inps[k].append(vars._ivar)
+                elif isinstance(vars, list) or isinstance(vars, tuple):
+                    for var in vars:
+                        op.previous_ops.append(var.op)
+                        inps[k].append(var._ivar)
+
+            op.outputs = outputs
+            outs = defaultdict(list)
+            for k, vars in six.iteritems(outputs):
+                if isinstance(vars, framework.Variable):
+                    vars.op = op
+                    outs[k].append(vars._ivar)
+                elif isinstance(vars, list) or isinstance(vars, tuple):
+                    for var in vars:
+                        var.op = op
+                        outs[k].append(var._ivar)
+
         # record op's trace id
         op.iop._trace_id = self._trace_id

-        backward_refs = self.trace(op.iop, op.inputs, op.outputs, op.attrs,
+        backward_refs = self.trace(op.iop, inps, outs, op.attrs,
                                    framework._current_expected_place(),
                                    stop_gradient)

-        if not stop_gradient:
+        if not stop_gradient and self._train_mode:
             self._trace_id += 1
             self._ops[op.iop._trace_id] = op
@@ -65,10 +110,16 @@ class Tracer(core.Tracer):
             # TODO(minqiyang): remove all inputs and outputs after separate
             # var and grad
             op.backward_refs = defaultdict(list)
-            for k, v in six.iteritems(op.inputs):
+            for k, v in six.iteritems(inputs):
                 if k in backward_refs:
-                    op.backward_refs[k] = op.inputs[k]
+                    op.backward_refs[k] = inputs[k]

-            for k, v in six.iteritems(op.outputs):
+            for k, v in six.iteritems(outputs):
                 if k in backward_refs:
-                    op.backward_refs[k] = op.outputs[k]
+                    op.backward_refs[k] = outputs[k]
+
+    def train_mode(self):
+        self._train_mode = True
+
+    def eval_mode(self):
+        self._train_mode = False
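A standalone sketch (toy Var class, not the fluid API) of the flattening that trace_op now performs: a dict mapping slot names to a Variable, or to a list/tuple of them, becomes a defaultdict(list) of the underlying _ivar handles:

from collections import defaultdict

class Var(object):
    """Toy stand-in for framework.Variable; _ivar is the C++-side handle."""
    def __init__(self, ivar):
        self._ivar = ivar

def flatten(slot_map):
    # Mirrors the inps/outs construction in trace_op: accept a single
    # Var or a list/tuple of Vars per slot, collect their _ivar handles.
    flat = defaultdict(list)
    for k, vars in slot_map.items():
        if isinstance(vars, Var):
            flat[k].append(vars._ivar)
        elif isinstance(vars, (list, tuple)):
            for var in vars:
                flat[k].append(var._ivar)
    return flat

print(flatten({"X": Var("x0"), "Y": [Var("y0"), Var("y1")]}))
# defaultdict(<class 'list'>, {'X': ['x0'], 'Y': ['y0', 'y1']})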
@@ -411,6 +411,7 @@ class Variable(object):
                 if persistable else False)
             if persistable:
                 _dygraph_tracer().trace_var(name, self)
+            self.op = None
         else:
             self.error_clip = error_clip
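In eval mode the tracer records op lineage instead of backward refs: every output Variable remembers its producer via the new op attribute, and each op collects the producers of its inputs in previous_ops. A minimal sketch of that bookkeeping (toy Var/Op classes, not fluid types):

class Var(object):
    def __init__(self):
        self.op = None               # set by whichever op produces this var

class Op(object):
    def __init__(self):
        self.previous_ops = []       # producer ops of this op's inputs

producer = Op()
x = Var()
x.op = producer                      # producer wrote x

consumer = Op()
consumer.previous_ops.append(x.op)   # consumer reads x, so it links to producer
assert consumer.previous_ops == [producer]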
@@ -939,24 +940,7 @@ class Operator(object):
                 raise ValueError(
                     "`type` to initialized an Operator can not be None.")
             self.iop = core.OpBase(type)
+            self.previous_ops = []

-            # TODO(minqiyang): remove these lines after we take apart all
-            # backward grads and forward variables
-            self.inputs = defaultdict(list)
-            if inputs is not None:
-                for k, v in six.iteritems(inputs):
-                    if isinstance(v, Variable):
-                        self.inputs[k].append(v._ivar)
-                    elif isinstance(v, list) or isinstance(v, tuple):
-                        self.inputs[k].extend([var._ivar for var in v])
-
-            self.outputs = defaultdict(list)
-            if outputs is not None:
-                for k, v in six.iteritems(outputs):
-                    if isinstance(v, Variable):
-                        self.outputs[k].append(v._ivar)
-                    elif isinstance(v, list) or isinstance(v, tuple):
-                        self.outputs[k].extend([var._ivar for var in v])
             self.attrs = attrs if attrs else {}
         else:
@@ -1647,15 +1631,18 @@ class Block(object):
                 block=self,
                 desc=None,
                 type=kwargs.get("type", None),
-                inputs=kwargs.get("inputs", None),
-                outputs=kwargs.get("outputs", None),
-                attrs=kwargs.get("attrs", None))
+                inputs=None,
+                outputs=None,
+                attrs=kwargs.get("attrs", {}))

             # record ops in tracer rather than blocks
             #
             # TODO(minqiyang): add op stop_gradient support in static mode too.
             # currently, we only support stop_gradient in dygraph mode.
-            _dygraph_tracer().trace_op(op, kwargs.get("stop_gradient", False))
+            _dygraph_tracer().trace_op(op,
+                                       kwargs.get("inputs", {}),
+                                       kwargs.get("outputs", {}),
+                                       kwargs.get("stop_gradient", False))
         else:
             op_desc = self.desc.append_op()
             op = Operator(
@@ -1719,10 +1706,14 @@ class Block(object):
                 self,
                 None,
                 type=kwargs.get("type", None),
-                inputs=kwargs.get("inputs", None),
-                outputs=kwargs.get("outputs", None),
-                attrs=kwargs.get("attrs", None))
-            _dygraph_tracer().trace_op(op, kwargs.get("stop_gradient", False))
+                inputs=None,
+                outputs=None,
+                attrs=kwargs.get("attrs", {}))
+
+            _dygraph_tracer().trace_op(op,
+                                       kwargs.get("inputs", {}),
+                                       kwargs.get("outputs", {}),
+                                       kwargs.get("stop_gradient", False))
         else:
             op_desc = self.desc._prepend_op()
             op = Operator(
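After this change a dygraph Operator is constructed slim (no cached input/output ivars) and Block.append_op forwards the raw keyword arguments to the tracer, which does the conversion itself. An illustrative call shape, sketched from the diff above (internal API, not meant to be called directly; assumes Variables x, y, out and the current block are in scope):

op = Operator(block=block, desc=None, type="mul",
              inputs=None, outputs=None, attrs={})    # op no longer caches ivars
_dygraph_tracer().trace_op(op,
                           {"X": x, "Y": y},          # raw Variables in...
                           {"Out": out},              # ...flattened to ivars inside
                           stop_gradient=False)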
...
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import contextlib
import unittest
import numpy as np
import six
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, FC
from paddle.fluid.dygraph.base import to_variable
from test_imperative_base import new_program_scope
class SimpleImgConvPool(fluid.dygraph.Layer):
def __init__(self,
name_scope,
num_channels,
num_filters,
filter_size,
pool_size,
pool_stride,
pool_padding=0,
pool_type='max',
global_pooling=False,
conv_stride=1,
conv_padding=0,
conv_dilation=1,
conv_groups=1,
act=None,
use_cudnn=False,
param_attr=None,
bias_attr=None):
super(SimpleImgConvPool, self).__init__(name_scope)
self._conv2d = Conv2D(
self.full_name(),
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=conv_stride,
padding=conv_padding,
dilation=conv_dilation,
groups=conv_groups,
param_attr=None,
bias_attr=None,
use_cudnn=use_cudnn)
self._pool2d = Pool2D(
self.full_name(),
pool_size=pool_size,
pool_type=pool_type,
pool_stride=pool_stride,
pool_padding=pool_padding,
global_pooling=global_pooling,
use_cudnn=use_cudnn)
def forward(self, inputs):
x = self._conv2d(inputs)
x = self._pool2d(x)
return x
class MNIST(fluid.dygraph.Layer):
def __init__(self, name_scope):
super(MNIST, self).__init__(name_scope)
self._simple_img_conv_pool_1 = SimpleImgConvPool(
self.full_name(), 1, 20, 5, 2, 2, act="relu")
self._simple_img_conv_pool_2 = SimpleImgConvPool(
self.full_name(), 20, 50, 5, 2, 2, act="relu")
pool_2_shape = 50 * 4 * 4
SIZE = 10
scale = (2.0 / (pool_2_shape**2 * SIZE))**0.5
self._fc = FC(self.full_name(),
10,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
loc=0.0, scale=scale)),
act="softmax")
def forward(self, inputs):
x = self._simple_img_conv_pool_1(inputs)
x = self._simple_img_conv_pool_2(x)
x = self._fc(x)
return x
class TestDygraphMultiForward(unittest.TestCase):
def test_mnist_forward_float32(self):
seed = 90
epoch_num = 1
with fluid.dygraph.guard():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
mnist = MNIST("mnist")
sgd = SGDOptimizer(learning_rate=1e-3)
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=128, drop_last=True)
dy_param_init_value = {}
mnist.eval()
for epoch in range(epoch_num):
for batch_id, data in enumerate(train_reader()):
dy_x_data = np.array(
[x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(128, 1)
img = to_variable(dy_x_data)
label = to_variable(y_data)
label.stop_gradient = True
cost = mnist(img)
loss = fluid.layers.cross_entropy(cost, label)
avg_loss = fluid.layers.mean(loss)
dy_out = avg_loss.numpy()
if epoch == 0 and batch_id == 0:
for param in mnist.parameters():
dy_param_init_value[param.name] = param.numpy()
with new_program_scope():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
exe = fluid.Executor(fluid.CPUPlace(
) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
mnist = MNIST("mnist")
sgd = SGDOptimizer(learning_rate=1e-3)
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=128, drop_last=True)
img = fluid.layers.data(
name='pixel', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
cost = mnist(img)
loss = fluid.layers.cross_entropy(cost, label)
avg_loss = fluid.layers.mean(loss)
# initialize params and fetch them
static_param_init_value = {}
static_param_name_list = []
for param in mnist.parameters():
static_param_name_list.append(param.name)
out = exe.run(fluid.default_startup_program(),
fetch_list=static_param_name_list)
for i in range(len(static_param_name_list)):
static_param_init_value[static_param_name_list[i]] = out[i]
for epoch in range(epoch_num):
for batch_id, data in enumerate(train_reader()):
static_x_data = np.array(
[x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape([128, 1])
fetch_list = [avg_loss.name]
out = exe.run(
fluid.default_main_program(),
feed={"pixel": static_x_data,
"label": y_data},
fetch_list=fetch_list)
static_out = out[0]
        self.assertTrue(np.allclose(dy_x_data, static_x_data))
for key, value in six.iteritems(static_param_init_value):
self.assertTrue(np.allclose(value, dy_param_init_value[key]))
self.assertTrue(np.allclose(static_out, dy_out))
if __name__ == '__main__':
unittest.main()