From 3ce4d34d01b24eeba4fc6e9e0a524181a993d031 Mon Sep 17 00:00:00 2001 From: 123malin Date: Fri, 8 Jan 2021 14:12:44 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=902.0API=20CherryPick=E3=80=91LookAhead,?= =?UTF-8?q?=20ModelAverage,=20IndexSelect=20(#30205)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add Lookahead and ModelAverage Optimizer (#30004) * test=develop, add model_average and lookahead * Improve Index select cuda kernel (#30139) * test=develop, add index_select_cuda kernel --- paddle/fluid/operators/index_select_op.cu | 183 +++++- paddle/fluid/pybind/op_function_generator.cc | 3 + python/paddle/__init__.py | 1 + .../fluid/tests/unittests/test_lookahead.py | 146 +++++ .../tests/unittests/test_modelaverage.py | 209 +++++++ python/paddle/incubate/__init__.py | 6 +- python/paddle/incubate/optimizer/__init__.py | 18 + python/paddle/incubate/optimizer/lookahead.py | 296 ++++++++++ .../paddle/incubate/optimizer/modelaverage.py | 525 ++++++++++++++++++ python/setup.py.in | 1 + 10 files changed, 1378 insertions(+), 10 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_lookahead.py create mode 100644 python/paddle/fluid/tests/unittests/test_modelaverage.py create mode 100644 python/paddle/incubate/optimizer/__init__.py create mode 100644 python/paddle/incubate/optimizer/lookahead.py create mode 100644 python/paddle/incubate/optimizer/modelaverage.py diff --git a/paddle/fluid/operators/index_select_op.cu b/paddle/fluid/operators/index_select_op.cu index 36a91d98a2..752e8b277d 100644 --- a/paddle/fluid/operators/index_select_op.cu +++ b/paddle/fluid/operators/index_select_op.cu @@ -12,18 +12,185 @@ // See the License for the specific language governing permissions and // limitations under the License. 
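+// The kernels added below implement index_select on the GPU. For each flattened
+// output element `idx`, the forward kernel maps back to the matching input
+// element along axis `dim` using three precomputed quantities:
+//   stride = input stride of `dim` (product of the dims after `dim`),
+//   size   = output_dim[dim]       (the number of selected indices),
+//   delta  = input_dim[dim] - size.
+// The result is equivalent to the NumPy reference (a sketch only, assuming a
+// tensor `x`, an integer index array `index`, and an axis `dim`):
+//   out = numpy.take(x, index, axis=dim)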
+#pragma once +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/index_select_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" + +namespace paddle { +namespace operators { + +using platform::PADDLE_CUDA_NUM_THREADS; +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +template +__global__ void index_select_cuda_kernel(const T* input, T* output, + const IndexT* index, int64_t N, + int64_t stride, int64_t size, + int64_t delta) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= N) { + return; + } + + int64_t pre_idx = idx / (stride * size); + int64_t dim_idx = idx % (stride * size) / stride; + IndexT src_dim_idx = index[dim_idx]; + int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride; + output[idx] = input[input_idx]; +} + +template +__global__ void index_select_grad_cuda_kernel(const T* output_grad, + T* input_grad, + const IndexT* index, int64_t nums, + int64_t N, int64_t stride, + int64_t size, int64_t delta) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= N) { + return; + } + + int64_t pre_idx = idx / (stride * size); + int64_t dim_idx = idx % (stride * size) / stride; + int64_t begin_idx = idx + (delta * pre_idx - dim_idx) * stride; + + input_grad[idx] = 0.0; + for (int64_t i = 0; i < nums; i++) { + if (index[i] == dim_idx) { + input_grad[idx] += output_grad[begin_idx + i * stride]; + } + } +} + +template +class IndexSelectCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in = context.Input("X"); + auto* index = context.Input("Index"); + auto* out = context.Output("Out"); + int dim = context.Attr("dim"); + auto input_dim = in->dims(); + auto output_dim = out->dims(); + dim = dim >= 0 ? 
dim : dim + input_dim.size(); + auto stride_dim = framework::stride(input_dim); + int64_t stride = stride_dim[dim]; + int64_t size = output_dim[dim]; + int64_t delta = input_dim[dim] - size; + + const auto& index_type = index->type(); + bool index_type_match = index_type == framework::proto::VarType::INT64 || + index_type == framework::proto::VarType::INT32; + PADDLE_ENFORCE_EQ(index_type_match, true, + platform::errors::InvalidArgument( + "Input(Index) holds the wrong type, it holds %s, but " + "desires to be %s or %s", + paddle::framework::DataTypeToString(index_type), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT32), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT64))); + + auto* in_data = in->data(); + auto* out_data = out->mutable_data(context.GetPlace()); + int64_t numel = out->numel(); + + auto stream = + context.template device_context().stream(); + + if (index_type == framework::proto::VarType::INT64) { + const int64_t* index_data = index->data(); + index_select_cuda_kernel<<< + (numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_data, out_data, index_data, + numel, stride, size, delta); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); + } else { + const int* index_data = index->data(); + index_select_cuda_kernel<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / + PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>( + in_data, out_data, index_data, numel, stride, size, delta); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); + } + } +}; + +template +class IndexSelectGradCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* output_grad = context.Input(framework::GradVarName("Out")); + auto* in_grad = context.Output(framework::GradVarName("X")); + auto* index = context.Input("Index"); + + auto* output_grad_data = output_grad->data(); + auto* in_grad_data = in_grad->mutable_data(context.GetPlace()); + + int dim = context.Attr("dim"); + auto input_dim = in_grad->dims(); + auto output_dim = output_grad->dims(); + dim = dim >= 0 ? 
dim : dim + input_dim.size(); + auto stride_dim = framework::stride(input_dim); + int64_t stride = stride_dim[dim]; + int64_t size = input_dim[dim]; + int64_t delta = output_dim[dim] - size; + + const auto& index_type = index->type(); + bool index_type_match = index_type == framework::proto::VarType::INT64 || + index_type == framework::proto::VarType::INT32; + PADDLE_ENFORCE_EQ(index_type_match, true, + platform::errors::InvalidArgument( + "Input(Index) holds the wrong type, it holds %s, but " + "desires to be %s or %s", + paddle::framework::DataTypeToString(index_type), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT32), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT64))); + + int64_t numel = in_grad->numel(); + int64_t index_nums = index->numel(); + + auto stream = + context.template device_context().stream(); + + if (index_type == framework::proto::VarType::INT64) { + const int64_t* index_data = index->data(); + index_select_grad_cuda_kernel<<< + (numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(output_grad_data, in_grad_data, + index_data, index_nums, numel, + stride, size, delta); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); + } else { + const int* index_data = index->data(); + index_select_grad_cuda_kernel<<< + (numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(output_grad_data, in_grad_data, + index_data, index_nums, numel, + stride, size, delta); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); + } + } +}; + +} // namespace operators +} // namespace paddle namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( index_select, - ops::IndexSelectKernel, - ops::IndexSelectKernel, - ops::IndexSelectKernel, - ops::IndexSelectKernel); + ops::IndexSelectCUDAKernel, + ops::IndexSelectCUDAKernel, + ops::IndexSelectCUDAKernel, + ops::IndexSelectCUDAKernel); REGISTER_OP_CUDA_KERNEL( index_select_grad, - ops::IndexSelectGradKernel, - ops::IndexSelectGradKernel, - ops::IndexSelectGradKernel, - ops::IndexSelectGradKernel); + ops::IndexSelectGradCUDAKernel, + ops::IndexSelectGradCUDAKernel, + ops::IndexSelectGradCUDAKernel, + ops::IndexSelectGradCUDAKernel); diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index 07218b8f3e..38e2eb4e7c 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -101,6 +101,9 @@ std::map> op_passing_outs_map = { {"sgd", {"ParamOut"}}, {"adam", {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}}, + {"average_accumulates", + {"out_sum_1", "out_sum_2", "out_sum_3", "out_num_accumulates", + "out_old_num_accumulates", "out_num_updates"}}, {"momentum", {"ParamOut", "VelocityOut"}}, {"batch_norm", {"MeanOut", "VarianceOut"}}, {"sync_batch_norm", {"MeanOut", "VarianceOut"}}, diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 1d5f4cc1df..74c8695897 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -43,6 +43,7 @@ import paddle.optimizer import paddle.metric import paddle.device import paddle.regularizer +import paddle.incubate # TODO: define alias in tensor and framework directory diff --git a/python/paddle/fluid/tests/unittests/test_lookahead.py b/python/paddle/fluid/tests/unittests/test_lookahead.py new file mode 100644 index 0000000000..98349be93d --- /dev/null +++ 
b/python/paddle/fluid/tests/unittests/test_lookahead.py @@ -0,0 +1,146 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +from paddle.fluid import core +from paddle.fluid.op import Operator +import paddle.fluid as fluid +import paddle +import paddle.nn as nn + +LOOKAHEAD_K = 5 +LOOKAHEAD_ALPHA = 0.2 +SGD_LR = 1.0 + + +class TestLookAhead(unittest.TestCase): + def test_lookahead_static(self): + paddle.enable_static() + place = fluid.CPUPlace() + shape = [2, 3, 8, 8] + exe = fluid.Executor(place) + train_program = fluid.Program() + startup = fluid.Program() + with fluid.program_guard(train_program, startup): + with fluid.unique_name.guard(): + data = fluid.data(name='X', shape=[None, 1], dtype='float32') + hidden = fluid.layers.fc(input=data, size=10) + loss = fluid.layers.mean(hidden) + + optimizer = paddle.optimizer.SGD(learning_rate=SGD_LR) + lookahead = paddle.incubate.optimizer.LookAhead( + optimizer, alpha=LOOKAHEAD_ALPHA, k=LOOKAHEAD_K) + lookahead.minimize(loss) + + exe.run(startup) + slow_param = None + fast_param = None + for i in range(10): + if (i + 1) % LOOKAHEAD_K == 0: + slow_param = slow_param + LOOKAHEAD_ALPHA * (fast_param - + slow_param) + x = np.random.random(size=(10, 1)).astype('float32') + latest_b, b_grad = exe.run(program=train_program, + feed={'X': x}, + fetch_list=[ + 'fc_0.b_0', + 'fc_0.b_0@GRAD', + ]) + if i == 0: + slow_param = latest_b + if (i + 1) % LOOKAHEAD_K == 0: + self.assertAlmostEqual( + slow_param.all(), latest_b.all(), delta=5e-3) + fast_param = latest_b - SGD_LR * b_grad + + def test_look_ahead_dygraph(self): + BATCH_SIZE = 16 + BATCH_NUM = 4 + EPOCH_NUM = 4 + + IMAGE_SIZE = 784 + CLASS_NUM = 10 + + # define a random dataset + class RandomDataset(paddle.io.Dataset): + def __init__(self, num_samples): + self.num_samples = num_samples + + def __getitem__(self, idx): + image = np.random.random([IMAGE_SIZE]).astype('float32') + label = np.random.randint(0, CLASS_NUM - 1, + (1, )).astype('int64') + return image, label + + def __len__(self): + return self.num_samples + + class LinearNet(nn.Layer): + def __init__(self): + super(LinearNet, self).__init__() + self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM) + self.bias = self._linear.bias + + @paddle.jit.to_static + def forward(self, x): + return self._linear(x) + + def train(layer, loader, loss_fn, opt): + idx = 0 + slow_param = None + fast_param = None + for epoch_id in range(EPOCH_NUM): + for batch_id, (image, label) in enumerate(loader()): + idx += 1 + out = layer(image) + loss = loss_fn(out, label) + loss.backward() + fast_param = layer.bias.numpy() - SGD_LR * layer.bias.grad + opt.step() + if idx == 1: + slow_param = fast_param + if idx % LOOKAHEAD_K == 0: + slow_param = slow_param + LOOKAHEAD_ALPHA * ( + fast_param - slow_param) + self.assertAlmostEqual( + np.mean(slow_param), + np.mean(layer.bias.numpy()), + delta=5e-3) + 
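+ # The manual bookkeeping above replays LookAhead's rule: fast_param tracks the
+ # inner SGD step, and every LOOKAHEAD_K steps slow_param moves toward fast_param
+ # by LOOKAHEAD_ALPHA; the optimizer is expected to have written that same value
+ # back into layer.bias, which the assert checks.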
opt.clear_grad() + + layer = LinearNet() + loss_fn = nn.CrossEntropyLoss() + optimizer = paddle.optimizer.SGD(learning_rate=SGD_LR, + parameters=layer.parameters()) + lookahead = paddle.incubate.optimizer.LookAhead( + optimizer, alpha=LOOKAHEAD_ALPHA, k=LOOKAHEAD_K) + + # create data loader + dataset = RandomDataset(BATCH_NUM * BATCH_SIZE) + loader = paddle.io.DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + num_workers=2) + + train(layer, loader, loss_fn, lookahead) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_modelaverage.py b/python/paddle/fluid/tests/unittests/test_modelaverage.py new file mode 100644 index 0000000000..8dab35f7f5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_modelaverage.py @@ -0,0 +1,209 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +from paddle.fluid import core +from paddle.fluid.op import Operator +import paddle.fluid as fluid +import paddle +import paddle.nn as nn + + +class TestModelAverage(unittest.TestCase): + def test_model_average_static(self): + paddle.enable_static() + place = fluid.CPUPlace() + shape = [2, 3, 8, 8] + exe = fluid.Executor(place) + train_program = fluid.Program() + startup = fluid.Program() + test_program = fluid.Program() + with fluid.program_guard(train_program, startup): + with fluid.unique_name.guard(): + data = fluid.data(name='X', shape=[None, 1], dtype='float32') + hidden = fluid.layers.fc(input=data, size=10) + loss = fluid.layers.mean(hidden) + test_program = train_program.clone() + optimizer = paddle.optimizer.Momentum( + learning_rate=0.2, momentum=0.1) + + optimizer.minimize(loss) + # build ModelAverage optimizer + model_average = paddle.incubate.optimizer.ModelAverage( + 0.15, min_average_window=2, max_average_window=10) + + exe.run(startup) + for i in range(10): + x = np.random.random(size=(10, 1)).astype('float32') + latest_b, sum_1, sum_2, sum_3, num_accumulates, old_num_accumulates, num_updates = exe.run( + program=train_program, + feed={'X': x}, + fetch_list=[ + 'fc_0.b_0', 'fc_0.b_0_sum_1_0', 'fc_0.b_0_sum_2_0', + 'fc_0.b_0_sum_3_0', 'fc_0.b_0_num_accumulates_0', + 'fc_0.b_0_old_num_accumulates_0', 'fc_0.b_0_num_updates_0' + ]) + self.assertTrue( + np.equal( + sum_1, np.zeros( + shape=[10], dtype='float32')).all()) + self.assertTrue( + np.equal( + sum_2, np.zeros( + shape=[10], dtype='float32')).all()) + self.assertTrue( + np.equal( + num_accumulates, np.array( + [0], dtype='int64')).all()) + self.assertTrue( + np.equal( + old_num_accumulates, np.array( + [2], dtype='int64')).all()) + self.assertTrue( + np.equal( + num_updates, np.array( + [10], dtype='int64')).all()) + + average_b = (sum_1 + sum_2 + sum_3) / ( + num_accumulates + old_num_accumulates) + # apply ModelAverage + with model_average.apply(exe): + x = np.random.random(size=(10, 
1)).astype('float32') + outs, b = exe.run(program=test_program, + feed={'X': x}, + fetch_list=[loss.name, 'fc_0.b_0']) + self.assertAlmostEqual(np.mean(average_b), np.mean(b)) + + x = np.random.random(size=(10, 1)).astype('float32') + outs, b = exe.run(program=test_program, + feed={'X': x}, + fetch_list=[loss.name, 'fc_0.b_0']) + self.assertAlmostEqual(np.mean(latest_b), np.mean(b)) + + def test_model_average_dygraph(self): + BATCH_SIZE = 16 + BATCH_NUM = 4 + EPOCH_NUM = 4 + + IMAGE_SIZE = 784 + CLASS_NUM = 10 + + # define a random dataset + class RandomDataset(paddle.io.Dataset): + def __init__(self, num_samples): + self.num_samples = num_samples + + def __getitem__(self, idx): + image = np.random.random([IMAGE_SIZE]).astype('float32') + label = np.random.randint(0, CLASS_NUM - 1, + (1, )).astype('int64') + return image, label + + def __len__(self): + return self.num_samples + + class LinearNet(nn.Layer): + def __init__(self): + super(LinearNet, self).__init__() + self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM) + self.bias = self._linear.bias + + @paddle.jit.to_static + def forward(self, x): + return self._linear(x) + + def train(layer, loader, loss_fn, opt, model_average): + for epoch_id in range(EPOCH_NUM): + for batch_id, (image, label) in enumerate(loader()): + out = layer(image) + loss = loss_fn(out, label) + loss.backward() + opt.step() + model_average.step() + opt.clear_grad() + model_average.clear_grad() + # print("Train Epoch {} batch {}: loss = {}, bias = {}".format( + # epoch_id, batch_id, np.mean(loss.numpy()), layer.bias.numpy())) + sum_1 = model_average._get_accumulator('sum_1', layer.bias) + sum_2 = model_average._get_accumulator('sum_2', layer.bias) + sum_3 = model_average._get_accumulator('sum_3', layer.bias) + num_accumulates = model_average._get_accumulator('num_accumulates', + layer.bias) + old_num_accumulates = model_average._get_accumulator( + 'old_num_accumulates', layer.bias) + num_updates = model_average._get_accumulator('num_updates', + layer.bias) + + return ((sum_1 + sum_2 + sum_3) / + (num_accumulates + old_num_accumulates)).numpy() + + def evaluate(layer, loader, loss_fn, check_param): + for batch_id, (image, label) in enumerate(loader()): + out = layer(image) + loss = loss_fn(out, label) + loss.backward() + self.assertAlmostEqual( + np.mean(layer.bias.numpy()), + np.mean(check_param), + delta=5e-3) + # print("Evaluate batch {}: loss = {}, bias = {}".format( + # batch_id, np.mean(loss.numpy()), layer.bias.numpy())) + + # create network + + layer = LinearNet() + loss_fn = nn.CrossEntropyLoss() + optimizer = paddle.optimizer.Momentum( + learning_rate=0.2, momentum=0.1, parameters=layer.parameters()) + # build ModelAverage optimizer + model_average = paddle.incubate.optimizer.ModelAverage( + 0.15, + parameters=layer.parameters(), + min_average_window=2, + max_average_window=10) + + # create data loader + dataset = RandomDataset(BATCH_NUM * BATCH_SIZE) + loader = paddle.io.DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + num_workers=2) + eval_loader = paddle.io.DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + num_workers=1) + # train + check_param = train(layer, loader, loss_fn, optimizer, model_average) + # print(check_param) + with model_average.apply(need_restore=False): + evaluate(layer, eval_loader, loss_fn, check_param) + + check_param = (model_average._get_accumulator('restore', + layer.bias)).numpy() + # print(check_param) + # print("\nEvaluate With Restored Paramters") + 
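+ # 'restore' is the accumulator holding the parameter values saved before the
+ # averaged weights were applied; restore() copies it back into layer.bias, so
+ # the following evaluate() should see the original (non-averaged) parameters.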
model_average.restore() + evaluate(layer, eval_loader, loss_fn, check_param) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/incubate/__init__.py b/python/paddle/incubate/__init__.py index 2af9255971..f7c3b00d02 100644 --- a/python/paddle/incubate/__init__.py +++ b/python/paddle/incubate/__init__.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from . import optimizer +from ..fluid.contrib import reader + __all__ = [] __all__ += ["reader"] - -from ..fluid.contrib import reader +__all__ += optimizer.__all__ diff --git a/python/paddle/incubate/optimizer/__init__.py b/python/paddle/incubate/optimizer/__init__.py new file mode 100644 index 0000000000..4a3889d0ee --- /dev/null +++ b/python/paddle/incubate/optimizer/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .lookahead import LookAhead +from .modelaverage import ModelAverage + +__all__ = ['LookAhead', 'ModelAverage'] diff --git a/python/paddle/incubate/optimizer/lookahead.py b/python/paddle/incubate/optimizer/lookahead.py new file mode 100644 index 0000000000..3dca25c2bf --- /dev/null +++ b/python/paddle/incubate/optimizer/lookahead.py @@ -0,0 +1,296 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.optimizer import Optimizer +from paddle.fluid import core, framework, layers, unique_name +from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program, device_guard +from paddle.fluid.layer_helper import LayerHelper +import paddle +import numpy as np +from paddle.fluid.dygraph import base as imperative_base + +__all__ = ["LookAhead"] + + +class LookAhead(Optimizer): + r""" + This implements the Lookahead optimizer of the + paper : https://arxiv.org/abs/1907.08610. + + Lookahead keeps two sets of params: the fast_params and + the slow_params. inner_optimizer update fast_params every + training step. Lookahead updates the slow_params and fast_params + every k training steps as follows: + + .. math:: + + slow\_param_t &= slow\_param_{t-1} + \\alpha * (fast\_param_{t-1} - slow\_param_{t-1}) + + fast\_param_t &= slow\_param_t + + Args: + inner_optimizer (Optimizer): The optimizer that update fast params step by step. + alpha (float, optinal): The learning rate of Lookahead. The default value is 0.5. 
+ k (int, optinal): The slow params is updated every k steps. The default value is 5. + name (str, optional): Normally there is no need for user to set this property. + For more information, please refer to :ref:`api_guide_Name`. + The default value is None. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + import paddle.nn as nn + + BATCH_SIZE = 16 + BATCH_NUM = 4 + EPOCH_NUM = 4 + + IMAGE_SIZE = 784 + CLASS_NUM = 10 + # define a random dataset + class RandomDataset(paddle.io.Dataset): + def __init__(self, num_samples): + self.num_samples = num_samples + + def __getitem__(self, idx): + image = np.random.random([IMAGE_SIZE]).astype('float32') + label = np.random.randint(0, CLASS_NUM - 1, + (1, )).astype('int64') + return image, label + + def __len__(self): + return self.num_samples + + class LinearNet(nn.Layer): + def __init__(self): + super(LinearNet, self).__init__() + self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM) + self.bias = self._linear.bias + + @paddle.jit.to_static + def forward(self, x): + return self._linear(x) + + def train(layer, loader, loss_fn, opt): + for epoch_id in range(EPOCH_NUM): + for batch_id, (image, label) in enumerate(loader()): + out = layer(image) + loss = loss_fn(out, label) + loss.backward() + opt.step() + opt.clear_grad() + print("Train Epoch {} batch {}: loss = {}".format( + epoch_id, batch_id, np.mean(loss.numpy()))) + + layer = LinearNet() + loss_fn = nn.CrossEntropyLoss() + optimizer = paddle.optimizer.SGD(learning_rate=0.1, parameters=layer.parameters()) + lookahead = paddle.incubate.optimizer.LookAhead(optimizer, alpha=0.2, k=5) + + # create data loader + dataset = RandomDataset(BATCH_NUM * BATCH_SIZE) + loader = paddle.io.DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + num_workers=2) + + train(layer, loader, loss_fn, lookahead) + + """ + _slow_str = "slow" + + def __init__(self, inner_optimizer, alpha=0.5, k=5, name=None): + assert (inner_optimizer is not None), "inner optimizer can not be None" + assert ( + 0.0 <= alpha <= 1.0 + ), "alpha should be larger or equal to 0.0, and less or equal than 1.0" + assert (isinstance(k, int) and k > 0), "k should be a positive integer" + + self.inner_optimizer = inner_optimizer + if self.inner_optimizer._parameter_list is None: + parameters = framework.default_main_program().global_block( + ).all_parameters() + else: + parameters = self.inner_optimizer._parameter_list + + super(LookAhead, self).__init__( + learning_rate=alpha, + parameters=parameters, + weight_decay=None, + grad_clip=None, + name=name) + + self.alpha = alpha + self.k = k + self.type = "lookahead" + self.helper = LayerHelper(self.__class__.__name__) + self._global_step_var = None + self._k_var = None + + @framework.dygraph_only + @imperative_base.no_grad + def step(self): + """ + Execute the optimizer and update parameters once. + + Returns: + None + + Examples: + + .. 
code-block:: python + + import paddle + import numpy as np + inp = paddle.to_tensor(np.random.random([1, 10]).astype('float32')) + linear = paddle.nn.Linear(10, 1) + out = linear(inp) + loss = paddle.mean(out) + sgd = paddle.optimizer.SGD(learning_rate=0.1,parameters=linear.parameters()) + lookahead = paddle.incubate.optimizer.LookAhead(sgd, alpha=0.2, k=5) + loss.backward() + lookahead.step() + lookahead.clear_grad() + + """ + self.inner_optimizer.step() + + params_grads = [] + for param in self._parameter_list: + if not param.trainable: + continue + if param._grad_ivar() is not None: + grad_var = param._grad_ivar() + params_grads.append((param, grad_var)) + + self._apply_optimize( + loss=None, startup_program=None, params_grads=params_grads) + + def _create_accumulators(self, block, parameters): + assert isinstance(block, framework.Block) + + for p in parameters: + self._add_accumulator(self._slow_str, p) + + def _append_optimize_op(self, block, param_and_grad): + if self._global_step_var is None: + self._global_step_var = layers.create_global_var( + name=unique_name.generate("lookahead_step"), + shape=[1], + value=0, + dtype='int32', + persistable=True) + + self.helper.append_op( + type='increment', + inputs={'X': [self._global_step_var]}, + outputs={'Out': [self._global_step_var]}, + attrs={'step': 1.0}) + + one_var = paddle.ones(shape=[1], dtype='int32', name='lookahead_ones') + zero_var = paddle.zeros( + shape=[1], dtype='int32', name='lookahead_zeros') + k_var = layers.create_global_var( + name=unique_name.generate("lookahead_k"), + shape=[1], + value=self.k, + dtype='int32', + persistable=True) + + mod = paddle.remainder(self._global_step_var, k_var) + + cond_1 = paddle.equal(self._global_step_var, one_var) + cond_1 = paddle.cast(cond_1, dtype='float32') + + cond_2 = paddle.equal(mod, zero_var) + cond_2 = paddle.cast(cond_2, dtype='float32') + + slow_var = self._get_accumulator(self._slow_str, param_and_grad[0]) + + tmp_var = cond_1 * param_and_grad[0] + (1 - cond_1) * slow_var + paddle.assign(tmp_var, slow_var) + + tmp_var = self.alpha * param_and_grad[0] + (1.0 - self.alpha) * slow_var + tmp_var_1 = cond_2 * tmp_var + (1 - cond_2) * param_and_grad[0] + paddle.assign(tmp_var_1, param_and_grad[0]) + + tmp_var_1 = cond_2 * tmp_var + (1 - cond_2) * slow_var + paddle.assign(tmp_var_1, slow_var) + + @imperative_base.no_grad + def minimize(self, + loss, + startup_program=None, + parameters=None, + no_grad_set=None): + """ + Add operations to minimize ``loss`` by updating ``parameters``. + + Args: + loss (Tensor): A ``Tensor`` containing the value to minimize. + startup_program (Program, optional): :ref:`api_fluid_Program` for + initializing parameters in ``parameters``. The default value + is None, at this time :ref:`api_fluid_default_startup_program` will be used. + parameters (list, optional): List of ``Tensor`` or ``Tensor.name`` to update + to minimize ``loss``. The default value is None, at this time all parameters + will be updated. + no_grad_set (set, optional): Set of ``Tensor`` or ``Tensor.name`` that don't need + to be updated. The default value is None. + + Returns: + tuple: tuple (optimize_ops, params_grads), A list of operators appended + by minimize and a list of (param, grad) tensor pairs, param is + ``Parameter``, grad is the gradient value corresponding to the parameter. + In static graph mode, the returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to + indicate program pruning. 
If so, the program will be pruned by ``feed`` and + ``fetch_list`` before run, see details in ``Executor``. + + Examples: + + .. code-block:: python + + import paddle + import numpy as np + inp = paddle.to_tensor(np.random.random([1, 10]).astype('float32')) + linear = paddle.nn.Linear(10, 1) + out = linear(inp) + loss = paddle.mean(out) + sgd = paddle.optimizer.SGD(learning_rate=0.1,parameters=linear.parameters()) + lookahead = paddle.incubate.optimizer.LookAhead(sgd, alpha=0.2, k=5) + loss.backward() + lookahead.minimize(loss) + lookahead.clear_grad() + + """ + assert isinstance(loss, Variable), "The loss should be an Tensor." + + parameter_list = parameters if parameters \ + else self._parameter_list + + # Apply inner optimizer to the main_program + optimize_ops, params_grads = self.inner_optimizer.minimize( + loss, + startup_program=startup_program, + parameters=parameters, + no_grad_set=no_grad_set) + + _ = self._apply_optimize( + loss, startup_program=startup_program, params_grads=params_grads) + + return optimize_ops, params_grads diff --git a/python/paddle/incubate/optimizer/modelaverage.py b/python/paddle/incubate/optimizer/modelaverage.py new file mode 100644 index 0000000000..8afcaf9207 --- /dev/null +++ b/python/paddle/incubate/optimizer/modelaverage.py @@ -0,0 +1,525 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.optimizer import Optimizer +from paddle.fluid import core, framework, layers +from paddle.fluid.framework import Program, Variable +from paddle.fluid.layer_helper import LayerHelper +import paddle +import numpy as np +from paddle.fluid.dygraph import base as imperative_base +from paddle.fluid.wrapped_decorator import signature_safe_contextmanager + +__all__ = ["ModelAverage"] + + +class ModelAverage(Optimizer): + r""" + The ModelAverage optimizer accumulates specific continuous historical + parameters during training. The accumulated historical range can be controlled + by the passed ``average_window_rate`` argument. The averaged ``Parameter`` are + used in the prediction, which usually can improve the accuracy of the prediction. + + Accumulate the average of the ``Parameter`` in the sliding window, the result will be saved + in a temporary variable, can be applied to the current model's ``Parameter`` by calling + the ``apply()`` method, and the current model ``Parameter`` can be restored by calling + the ``restore()`` method. + + The window size for calculating the average is determined by ``average_window_rate``, + ``min_average_window``, ``max_average_window`` and the current ``Parameter`` update times (num_updates). + + When the cumulative times (num_accumulates) is greater than the specific window + threshold (average_window), the accumulated ``Parameter`` temporary variable is set to 0.0. 
+ The following example will help to understand the role of these arguments: + + :: + + if num_accumulates >= min_average_window and num_accumulates >= min(max_average_window, num_updates * average_window_rate): + num_accumulates = 0 + + In the above conditional judgment statement, ``num_accumulates`` indicates the current + accumulated number, which can be abstractly understood as the length of the cumulative window. + The length of the window must be at least the length set by the ``min_average_window`` argument, + and cannot exceed the length specified by the ``max_average_window`` argument or + ``num_updates * average_window_rate``, where ``num_updates`` indicates the current ``Parameter`` + update times, ``average_window_rate`` is a coefficient that calculates the length of the window. + + Args: + average_window_rate (float): The calculate ratio of the window length relative to ``Parameter`` update times. + parameters (list, optional): List of ``Tensor`` names to update to minimize ``loss``. \ + This parameter is required in dygraph mode. \ + The default value is None in static mode, at this time all parameters will be updated. + min_average_window (int, optional): the minimum size of average window length. The default value is 10000. + max_average_window (int, optional): The maximum size of average window length. The default value is 10000. + name (str, optional): Normally there is no need for user to set this property. + For more information, please refer to :ref:`api_guide_Name`. + The default value is None. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + import paddle.nn as nn + import paddle.optimizer as opt + + BATCH_SIZE = 16 + BATCH_NUM = 4 + EPOCH_NUM = 4 + + IMAGE_SIZE = 784 + CLASS_NUM = 10 + + # define a random dataset + class RandomDataset(paddle.io.Dataset): + def __init__(self, num_samples): + self.num_samples = num_samples + + def __getitem__(self, idx): + image = np.random.random([IMAGE_SIZE]).astype('float32') + label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64') + return image, label + + def __len__(self): + return self.num_samples + + class LinearNet(nn.Layer): + def __init__(self): + super(LinearNet, self).__init__() + self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM) + self.bias = self._linear.bias + + @paddle.jit.to_static + def forward(self, x): + return self._linear(x) + + def train(layer, loader, loss_fn, opt, model_average): + for epoch_id in range(EPOCH_NUM): + for batch_id, (image, label) in enumerate(loader()): + out = layer(image) + loss = loss_fn(out, label) + loss.backward() + opt.step() + model_average.step() + opt.clear_grad() + model_average.clear_grad() + print("Train Epoch {} batch {}: loss = {}, bias = {}".format( + epoch_id, batch_id, np.mean(loss.numpy()), layer.bias.numpy())) + def evaluate(layer, loader, loss_fn): + for batch_id, (image, label) in enumerate(loader()): + out = layer(image) + loss = loss_fn(out, label) + loss.backward() + print("Evaluate batch {}: loss = {}, bias = {}".format( + batch_id, np.mean(loss.numpy()), layer.bias.numpy())) + + # create network + layer = LinearNet() + loss_fn = nn.CrossEntropyLoss() + optimizer = opt.Momentum(learning_rate=0.2, momentum=0.1, parameters=layer.parameters()) + model_average = paddle.incubate.optimizer.ModelAverage(0.15, + parameters=layer.parameters(), + min_average_window=2, + max_average_window=10) + + # create data loader + dataset = RandomDataset(BATCH_NUM * BATCH_SIZE) + loader = paddle.io.DataLoader(dataset, + batch_size=BATCH_SIZE, + 
shuffle=True, + drop_last=True, + num_workers=2) + # create data loader + eval_loader = paddle.io.DataLoader(dataset, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + num_workers=1) + + # train + train(layer, loader, loss_fn, optimizer, model_average) + + print("\nEvaluate With ModelAverage") + with model_average.apply(need_restore=False): + evaluate(layer, eval_loader, loss_fn) + + print("\nEvaluate With Restored Paramters") + model_average.restore() + evaluate(layer, eval_loader, loss_fn) + + """ + + def __init__(self, + average_window_rate, + parameters=None, + min_average_window=10000, + max_average_window=10000, + name=None): + super(ModelAverage, self).__init__( + learning_rate=0.0, + parameters=parameters, + weight_decay=None, + grad_clip=None, + name=name) + + self.helper = LayerHelper(self.__class__.__name__) + self.average_window = average_window_rate + self.min_average_window = min_average_window + self.max_average_window = max_average_window + self.type = "average_accumulates" + + if not framework.in_dygraph_mode(): + global_block = framework.default_main_program().global_block() + all_parameters = parameters if parameters else global_block.all_parameters( + ) + + self._create_accumulators(global_block, all_parameters) + for param in all_parameters: + self._append_optimize_op(global_block, [param, None]) + self.apply_program = Program() + block = self.apply_program.global_block() + with framework.program_guard(main_program=self.apply_program): + for param in all_parameters: + self._add_average_apply_op(block, param) + self.restore_program = Program() + block = self.restore_program.global_block() + with framework.program_guard(main_program=self.restore_program): + for param in all_parameters: + self._add_average_restore_op(block, param) + + def _create_accumulators(self, block, parameters): + assert isinstance(block, framework.Block) + + for param in parameters: + self._add_accumulator('sum_1', param) + self._add_accumulator('sum_2', param) + self._add_accumulator('sum_3', param) + self._add_accumulator('restore', param) + self._add_accumulator( + 'num_accumulates', param, dtype='int64', shape=[1]) + self._add_accumulator( + 'old_num_accumulates', param, dtype='int64', shape=[1]) + self._add_accumulator( + 'num_updates', param, dtype='int64', shape=[1]) + + def _append_optimize_op(self, block, param_and_grad): + assert isinstance(block, framework.Block) + + sum_1 = self._get_accumulator('sum_1', param_and_grad[0]) + sum_2 = self._get_accumulator('sum_2', param_and_grad[0]) + sum_3 = self._get_accumulator('sum_3', param_and_grad[0]) + num_accumulates = self._get_accumulator('num_accumulates', + param_and_grad[0]) + old_num_accumulates = self._get_accumulator('old_num_accumulates', + param_and_grad[0]) + num_updates = self._get_accumulator('num_updates', param_and_grad[0]) + if framework.in_dygraph_mode(): + _, _, _, _, _, _ = core.ops.average_accumulates( + param_and_grad[0], sum_1, sum_2, sum_3, num_accumulates, + old_num_accumulates, num_updates, sum_1, sum_2, sum_3, + num_accumulates, old_num_accumulates, num_updates, + 'average_window', self.average_window, 'min_average_window', + self.min_average_window, 'max_average_window', + self.max_average_window) + return None + + block = framework.default_main_program().global_block() + attrs = { + "average_window": self.average_window, + "min_average_window": self.min_average_window, + "max_average_window": self.max_average_window, + } + + inputs = { + "param": param_and_grad[0], + "in_sum_1": sum_1, + "in_sum_2": sum_2, 
+ "in_sum_3": sum_3, + "in_num_accumulates": num_accumulates, + "in_old_num_accumulates": old_num_accumulates, + "in_num_updates": num_updates + } + + outputs = { + "out_sum_1": sum_1, + "out_sum_2": sum_2, + "out_sum_3": sum_3, + "out_num_accumulates": num_accumulates, + "out_old_num_accumulates": old_num_accumulates, + "out_num_updates": num_updates, + } + + average_accumulates_op = block.append_op( + type=self.type, + inputs=inputs, + outputs=outputs, + attrs=attrs, + stop_gradient=True) + + return average_accumulates_op + + @imperative_base.no_grad + def minimize(self, + loss, + startup_program=None, + parameters=None, + no_grad_set=None): + """ + Add operations to minimize ``loss`` by updating ``parameters``. + + Args: + loss (Tensor): A ``Tensor`` containing the value to minimize. + startup_program (Program, optional): :ref:`api_fluid_Program` for + initializing parameters in ``parameters``. The default value + is None, at this time :ref:`api_fluid_default_startup_program` will be used. + parameters (list, optional): List of ``Tensor`` or ``Tensor.name`` to update + to minimize ``loss``. The default value is None, at this time all parameters + will be updated. + no_grad_set (set, optional): Set of ``Tensor`` or ``Tensor.name`` that don't need + to be updated. The default value is None. + + Returns: + tuple: tuple (optimize_ops, params_grads), A list of operators appended + by minimize and a list of (param, grad) tensor pairs, param is + ``Parameter``, grad is the gradient value corresponding to the parameter. + In static graph mode, the returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to + indicate program pruning. If so, the program will be pruned by ``feed`` and + ``fetch_list`` before run, see details in ``Executor``. + + Examples: + + .. code-block:: python + + import paddle + import numpy as np + inp = paddle.to_tensor(np.random.random([1, 10]).astype('float32')) + linear = paddle.nn.Linear(10, 1) + out = linear(inp) + loss = paddle.mean(out) + loss.backward() + + sgd = paddle.optimizer.SGD(learning_rate=0.1,parameters=linear.parameters()) + sgd.minimize(loss) + + modelaverage = paddle.incubate.optimizer.ModelAverage(0.15, + parameters=linear.parameters(), + min_average_window=2, + max_average_window=4) + modelaverage.minimize(loss) + sgd.clear_grad() + modelaverage.clear_grad() + + """ + if framework.in_dygraph_mode(): + self.step() + + @framework.dygraph_only + @imperative_base.no_grad + def step(self): + """ + Execute the optimizer and update parameters once. + + Returns: + None + + Examples: + + .. 
code-block:: python + + import paddle + import numpy as np + inp = paddle.to_tensor(np.random.random([1, 10]).astype('float32')) + linear = paddle.nn.Linear(10, 1) + out = linear(inp) + loss = paddle.mean(out) + sgd = paddle.optimizer.SGD(learning_rate=0.1,parameters=linear.parameters()) + modelaverage = paddle.incubate.optimizer.ModelAverage(0.15, + parameters=linear.parameters(), + min_average_window=2, + max_average_window=4) + loss.backward() + sgd.step() + modelaverage.step() + sgd.clear_grad() + modelaverage.clear_grad() + """ + + params_grads = [] + for param in self._parameter_list: + if not param.trainable: + continue + if param._grad_ivar() is not None: + grad_var = param._grad_ivar() + params_grads.append((param, grad_var)) + + block = framework.default_main_program().global_block() + self._create_accumulators(block, self._parameter_list) + for param_and_grad in params_grads: + self._append_optimize_op(block, param_and_grad) + + @signature_safe_contextmanager + @imperative_base.no_grad + def apply(self, executor=None, need_restore=True): + """ + Apply the average of the cumulative ``Parameter`` to the parameters of the current model. + + Args: + executor(Executor): The network executor in static-graph mode. The default value is None in dygraph mode. + need_restore(bool): Restore flag variable, if set to True, the network will restore + the parameters of the network to the default value, if set to False, + it will not be restored. The default value is True. + + Examples: + + .. code-block:: python + + import paddle + import numpy as np + inp = paddle.to_tensor(np.random.random([1, 10]).astype('float32')) + linear = paddle.nn.Linear(10, 1) + out = linear(inp) + loss = paddle.mean(out) + loss.backward() + + sgd = paddle.optimizer.SGD(learning_rate=0.1,parameters=linear.parameters()) + + modelaverage = paddle.incubate.optimizer.ModelAverage(0.15, + parameters=linear.parameters(), + min_average_window=2, + max_average_window=4) + sgd.step() + modelaverage.step() + + with modelaverage.apply(): + for param in linear.parameters(): + print(param) + + for param in linear.parameters(): + print(param) + """ + if framework.in_dygraph_mode(): + for param in self._parameter_list: + num_accumulates = self._get_accumulator('num_accumulates', + param) + old_num_accumulates = self._get_accumulator( + 'old_num_accumulates', param) + num_updates = self._get_accumulator('num_updates', param) + sum_1 = self._get_accumulator('sum_1', param) + sum_2 = self._get_accumulator('sum_2', param) + sum_3 = self._get_accumulator('sum_3', param) + param_restore = self._get_accumulator('restore', param) + + paddle.assign(param, param_restore) + total_param = sum_1 + sum_2 + sum_3 + total_accumulates = num_accumulates + old_num_accumulates + total_param = paddle.cast(total_param, dtype='float32') + total_accumulates = paddle.cast( + total_accumulates, dtype='float32') + average_param = total_param / total_accumulates + paddle.assign(average_param, param) + try: + yield + finally: + if need_restore: + self.restore() + return + if executor is None: + raise RuntimeError( + "Executor should not be None in static graph mode.") + executor.run(self.apply_program) + try: + yield + finally: + if need_restore: + self.restore(executor) + + @imperative_base.no_grad + def restore(self, executor=None): + """ + Restore ``Parameter`` values of current model. + + Args: + executor(Executor): The network executor in static-graph mode. The default value is None in dygraph mode + + Examples: + + .. 
code-block:: python + + import paddle + import numpy as np + inp = paddle.to_tensor(np.random.random([1, 10]).astype('float32')) + linear = paddle.nn.Linear(10, 1) + out = linear(inp) + loss = paddle.mean(out) + loss.backward() + + sgd = paddle.optimizer.SGD(learning_rate=0.1,parameters=linear.parameters()) + + modelaverage = paddle.incubate.optimizer.ModelAverage(0.15, + parameters=linear.parameters(), + min_average_window=2, + max_average_window=4) + sgd.step() + modelaverage.step() + + with modelaverage.apply(need_restore=False): + for param in linear.parameters(): + print(param) + + for param in linear.parameters(): + print(param) + + modelaverage.restore() + + for param in linear.parameters(): + print(param) + """ + if framework.in_dygraph_mode(): + for param in self._parameter_list: + param_restore = self._get_accumulator('restore', param) + paddle.assign(param_restore, param) + return + if executor is None: + raise RuntimeError( + "Executor should not be None in static graph mode.") + executor.run(self.restore_program) + + def _add_average_apply_op(self, block, param): + param = block._clone_variable(param) + grad = block._clone_variable(self._get_accumulator('restore', param)) + sum_1 = block._clone_variable(self._get_accumulator('sum_1', param)) + sum_2 = block._clone_variable(self._get_accumulator('sum_2', param)) + sum_3 = block._clone_variable(self._get_accumulator('sum_3', param)) + num_accumulates = block._clone_variable( + self._get_accumulator('num_accumulates', param)) + old_num_accumulates = block._clone_variable( + self._get_accumulator('old_num_accumulates', param)) + num_updates = block._clone_variable( + self._get_accumulator('num_updates', param)) + # backup param value to grad + layers.assign(input=param, output=grad) + # param = (sum_1 + sum_2 + sum_3) / (num_accumulates + old_num_accumulates) + tmp = layers.sum(x=[num_accumulates, old_num_accumulates]) + sum = layers.sum(x=[sum_1, sum_2, sum_3]) + tmp = layers.cast( + x=tmp, dtype='float32' if self._dtype == None else self._dtype) + sum = layers.cast( + x=sum, dtype='float32' if self._dtype == None else self._dtype) + layers.ops._elementwise_div(x=sum, y=tmp, out=param) + + def _add_average_restore_op(self, block, param): + param = block._clone_variable(param) + grad = block._clone_variable(self._get_accumulator('restore', param)) + layers.assign(input=grad, output=param) diff --git a/python/setup.py.in b/python/setup.py.in index f43a97bff3..428b0a057b 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -143,6 +143,7 @@ packages=['paddle', 'paddle.reader', 'paddle.distributed', 'paddle.incubate', + 'paddle.incubate.optimizer', 'paddle.distributed.fleet', 'paddle.distributed.fleet.base', 'paddle.distributed.fleet.meta_optimizers', -- GitLab
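For reference, the quantities that the new ModelAverage tests fetch by hand can be reproduced outside the framework. The sketch below is illustrative only: the helper names are invented here, and the 0.15/2/10 window settings simply mirror the unit tests above.

    import numpy as np

    def averaged_param(sum_1, sum_2, sum_3, num_accumulates, old_num_accumulates):
        # ModelAverage.apply(): param <- (sum_1 + sum_2 + sum_3)
        #                                / (num_accumulates + old_num_accumulates)
        return (sum_1 + sum_2 + sum_3) / float(num_accumulates + old_num_accumulates)

    def window_is_full(num_accumulates, num_updates,
                       average_window_rate=0.15,
                       min_average_window=2,
                       max_average_window=10):
        # Condition from the ModelAverage docstring: once the accumulation window
        # is long enough, num_accumulates is folded into old_num_accumulates and
        # reset to start a new window.
        return (num_accumulates >= min_average_window and
                num_accumulates >= min(max_average_window,
                                       num_updates * average_window_rate))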