Unverified commit a1d9a14e, authored by Chen Weihang, committed by GitHub

support grad accumulated across batch (#29942)

Parent bb20dcfc
@@ -45,6 +45,7 @@ class GradientAccumulator {
       inner_var_ = std::make_shared<VariableWrapper>(var->Name());
       inner_var_->SetType(var->Type());
       inner_var_->SetDataType(var->DataType());
+      inner_var_->SetForwardDataType(var->ForwardDataType());
       inner_var_->InnerSetOverridedStopGradient(
           var->InnerOverridedStopGradient());
       VLOG(6) << " Create inner grad var for (" << var->Name()
......
@@ -41,7 +41,6 @@ class Optimization_ex1(paddle.nn.Layer):
             np.random.random((4, 4)).astype(dtype) + np.random.random(
                 (4, 4)).astype(dtype) * 1j,
             stop_gradient=False)
-        print(self.A)
 
     def forward(self, mode=1):
         jj = paddle.to_tensor(np.array([1j]).astype(np.complex64))
@@ -70,31 +69,55 @@ class TestComplexGradAccumulated(unittest.TestCase):
         self.devices = ['cpu']
         if core.is_compiled_with_cuda():
             self.devices.append('gpu')
+        self.iter = 3
+        self.learning_rate = 0.5
         self.dtypes = ['float32', 'float64']
         self.theta_size = [4, 4]
 
-    def run_backward(self, device, dtype, mode):
+    def train(self, device, dtype, mode):
         paddle.set_device(device)
 
         myLayer = Optimization_ex1(self.theta_size, dtype)
+        optimizer = paddle.optimizer.SGD(learning_rate=self.learning_rate,
+                                         parameters=myLayer.parameters())
 
-        loss = myLayer(mode)
-        loss.backward()
+        for iter in range(self.iter):
+            loss = myLayer(mode)
+            loss.backward()
+            optimizer.step()
+            optimizer.clear_grad()
+
+    def train_no_clear_grad(self, device, dtype, mode):
+        paddle.set_device(device)
+
+        myLayer = Optimization_ex1(self.theta_size, dtype)
+        optimizer = paddle.optimizer.SGD(learning_rate=self.learning_rate,
+                                         parameters=myLayer.parameters())
+
+        for iter in range(self.iter):
+            loss = myLayer(mode)
+            loss.backward()
+            optimizer.step()
 
     def test_case_one_step(self):
         for dev in self.devices:
             for dtype in self.dtypes:
-                self.run_backward(dev, dtype, 1)
+                self.train(dev, dtype, 1)
+                self.train_no_clear_grad(dev, dtype, 1)
 
     def test_case_two_step(self):
         for dev in self.devices:
             for dtype in self.dtypes:
-                self.run_backward(dev, dtype, 2)
+                self.train(dev, dtype, 2)
+                self.train_no_clear_grad(dev, dtype, 2)
 
     def test_case_non_param(self):
         for dev in self.devices:
             for dtype in self.dtypes:
-                self.run_backward(dev, dtype, 3)
+                self.train(dev, dtype, 3)
+                self.train_no_clear_grad(dev, dtype, 3)
 
 
 if __name__ == '__main__':
......
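
For context, the new train / train_no_clear_grad tests above exercise gradient accumulation across batches in dygraph mode: calling backward() on several batches without clear_grad() keeps adding into the parameters' gradients, and optimizer.step() then applies the accumulated sum. Below is a minimal sketch of that usage pattern; the net, opt, batch shapes, and accum_steps are illustrative placeholders and are not part of this commit.

import paddle

# Illustrative model and optimizer (hypothetical names, not from the commit).
net = paddle.nn.Linear(4, 1)
opt = paddle.optimizer.SGD(learning_rate=0.1, parameters=net.parameters())
accum_steps = 2  # number of micro-batches to accumulate before one update

data = paddle.rand([16, 4])                # fake data: 16 samples of width 4
for step, x in enumerate(data.split(4)):   # 4 micro-batches of 4 samples each
    loss = net(x).mean()
    loss.backward()                        # gradients accumulate into param.grad
    if (step + 1) % accum_steps == 0:
        opt.step()                         # update with the accumulated gradients
        opt.clear_grad()                   # start a fresh accumulation window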