提交 250e72d2 编写于 作者: W WangXi 提交者: gongweibao

Fix DGC algorithm flow to make it the same as paper (#20758)

上级 ba45dce3
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/operators/optimizers/dgc_momentum_op.h"
namespace paddle {
namespace operators {
class DGCMomentumOp : public MomentumOp {
public:
using MomentumOp::MomentumOp;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("current_step"), true,
"current_step should be set.");
return MomentumOp::InferShape(ctx);
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
if (var_name == "current_step") {
VLOG(10) << "var_name:" << var_name << " need not to transform";
return expected_kernel_type;
}
return framework::OperatorWithKernel::GetKernelTypeForVar(
var_name, tensor, expected_kernel_type);
}
};
class DGCMomentumOpMaker : public MomentumOpMaker {
public:
void Make() override {
AddInput("current_step", "(Tensor) Current step.");
AddAttr<float>("rampup_begin_step",
"(float, -1.0)"
"The period when begin DGC.")
.SetDefault(-1.0);
return MomentumOpMaker::Make();
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(dgc_momentum, ops::DGCMomentumOp,
ops::DGCMomentumOpMaker);
REGISTER_OP_CPU_KERNEL(
dgc_momentum,
ops::DGCMomentumKernel<paddle::platform::CPUDeviceContext, float>);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/optimizers/dgc_momentum_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
dgc_momentum,
ops::DGCMomentumKernel<paddle::platform::CUDADeviceContext, float>);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/fluid/operators/optimizers/momentum_op.h"
#include "paddle/fluid/operators/optimizers/sgd_op.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class DGCMomentumKernel : public framework::OpKernel<T> {
public:
DGCMomentumKernel()
: _momentum_op_kernel(new MomentumOpKernel<DeviceContext, T>()),
_sgd_op_kernel(new SGDOpKernel<DeviceContext, T>()) {}
void Compute(const framework::ExecutionContext& context) const override {
auto rampup_begin_step = context.Attr<float>("rampup_begin_step");
if (static_cast<int>(rampup_begin_step) < 0) {
return;
}
auto current_step_tensor = context.Input<framework::Tensor>("current_step");
auto* current_step = current_step_tensor->data<T>();
VLOG(10) << "current_step:" << *current_step
<< ", rampup_begin_step:" << rampup_begin_step;
if (static_cast<int>(*current_step) < static_cast<int>(rampup_begin_step)) {
VLOG(10) << " so use momentum optimizer";
return _momentum_op_kernel->Compute(context);
}
VLOG(10) << " so use sgd optimizer";
return _sgd_op_kernel->Compute(context);
}
private:
std::unique_ptr<MomentumOpKernel<DeviceContext, T>> _momentum_op_kernel;
std::unique_ptr<SGDOpKernel<DeviceContext, T>> _sgd_op_kernel;
};
} // namespace operators
} // namespace paddle
......@@ -37,36 +37,34 @@ class MomentumOpInferVarType : public framework::VarTypeInference {
}
};
class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Param",
"(Tensor, default Tensor<float>) "
"Input parameter that has to be updated");
AddInput("Grad",
"(Tensor, default Tensor<float>) "
"Input gradient of the parameter");
AddInput("Velocity",
"(Tensor, default Tensor<float>) "
"Input velocity (corresponding to the parameter) "
"that has to be updated");
AddInput("LearningRate",
"(Tensor, default Tensor<float>) "
"Input learning rate");
void MomentumOpMaker::Make() {
AddInput("Param",
"(Tensor, default Tensor<float>) "
"Input parameter that has to be updated");
AddInput("Grad",
"(Tensor, default Tensor<float>) "
"Input gradient of the parameter");
AddInput("Velocity",
"(Tensor, default Tensor<float>) "
"Input velocity (corresponding to the parameter) "
"that has to be updated");
AddInput("LearningRate",
"(Tensor, default Tensor<float>) "
"Input learning rate");
AddOutput("ParamOut",
"(Tensor) This output is updated parameter. "
"It shared memory with Input(Param).");
AddOutput("VelocityOut",
"(Tensor) This output is updated velocity. "
"It shared memory with Input(Velocity).");
AddOutput("ParamOut",
"(Tensor) This output is updated parameter. "
"It shared memory with Input(Param).");
AddOutput("VelocityOut",
"(Tensor) This output is updated velocity. "
"It shared memory with Input(Velocity).");
AddAttr<float>("mu", "(float) Momentum coefficient");
AddAttr<bool>("use_nesterov",
"(bool, default false) "
"Use Nesterov Momentum")
.SetDefault(false);
AddComment(R"DOC(
AddAttr<float>("mu", "(float) Momentum coefficient");
AddAttr<bool>("use_nesterov",
"(bool, default false) "
"Use Nesterov Momentum")
.SetDefault(false);
AddComment(R"DOC(
Momentum Optimizer.
This optimizer has a flag for Nestrov Momentum.
......@@ -81,8 +79,8 @@ else: \\
$$
)DOC");
}
};
}
} // namespace operators
} // namespace paddle
......
......@@ -29,6 +29,11 @@ using framework::SelectedRows;
struct NoNesterov;
struct UseNesterov;
class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
class MomentumOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......
......@@ -110,4 +110,6 @@ $$param\_out = param - learning\_rate * grad$$
namespace ops = paddle::operators;
REGISTER_OPERATOR(sgd, ops::SGDOp, ops::SGDOpMaker,
paddle::framework::EmptyGradOpMaker, ops::SGDOpInferVarType);
REGISTER_OP_CPU_KERNEL(sgd, ops::SGDOpKernel<float>, ops::SGDOpKernel<double>);
REGISTER_OP_CPU_KERNEL(
sgd, ops::SGDOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::SGDOpKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -53,7 +53,8 @@ __global__ void SparseSGDFunctorKernel(const T* selected_rows,
} // namespace
template <typename T>
class SGDOpCUDAKernel : public framework::OpKernel<T> {
class SGDOpKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
......@@ -123,6 +124,7 @@ class SGDOpCUDAKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(sgd, ops::SGDOpCUDAKernel<float>,
ops::SGDOpCUDAKernel<double>,
ops::SGDOpCUDAKernel<plat::float16>);
REGISTER_OP_CUDA_KERNEL(
sgd, ops::SGDOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::SGDOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::SGDOpKernel<paddle::platform::CUDADeviceContext, plat::float16>);
......@@ -21,8 +21,15 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
template <typename DeviceContext, typename T>
class SGDOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override;
};
template <typename T>
class SGDOpKernel<platform::CPUDeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
......
......@@ -959,6 +959,47 @@ class DGCMomentumOptimizer(MomentumOptimizer):
super(DGCMomentumOptimizer, self).__init__(
learning_rate, momentum, use_nesterov, regularization, name)
def _is_use_dgc(self, param_var, grad_var):
var_numel = abs(reduce(lambda x, y: x * y, param_var.shape))
if var_numel < 16384 or \
param_var.type == core.VarDesc.VarType.SELECTED_ROWS or \
grad_var.type == core.VarDesc.VarType.SELECTED_ROWS or \
param_var.dtype != core.VarDesc.VarType.FP32 :
return False
return True
def _append_optimize_op(self, block, param_and_grad):
assert isinstance(block, framework.Block)
if not self._is_use_dgc(param_and_grad[0], param_and_grad[1]):
return super(DGCMomentumOptimizer, self)._append_optimize_op(
block, param_and_grad)
velocity_acc = self._get_accumulator(self._velocity_acc_str,
param_and_grad[0])
# create the dgc momentum optimize op
dgc_momentum_op = block.append_op(
type="dgc_momentum",
inputs={
"Param": param_and_grad[0],
"Grad": param_and_grad[1],
"Velocity": velocity_acc,
"LearningRate": self._create_param_lr(param_and_grad),
"current_step": self._global_step_var,
},
outputs={
"ParamOut": param_and_grad[0],
"VelocityOut": velocity_acc
},
attrs={
"mu": self._momentum,
"use_nesterov": self._use_nesterov,
"rampup_begin_step": float(self._rampup_begin_step)
},
stop_gradient=True)
return dgc_momentum_op
def _add_auto_increment_var(self, counter_name, begin, step=1):
helper = LayerHelper('global_step_counter')
counter, is_new_var = helper.create_or_get_global_variable(
......@@ -997,11 +1038,7 @@ class DGCMomentumOptimizer(MomentumOptimizer):
force_cpu=True)
for param_var, grad_var in param_and_grads:
var_numel = abs(reduce(lambda x, y: x * y, param_var.shape))
if var_numel < 16384 or \
param_var.type == core.VarDesc.VarType.SELECTED_ROWS or \
grad_var.type == core.VarDesc.VarType.SELECTED_ROWS or \
param_var.dtype != core.VarDesc.VarType.FP32 :
if not self._is_use_dgc(param_var, grad_var):
continue
u_var = tensor.create_global_var(
......
......@@ -10,6 +10,8 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_se_resnext)
set(MIXED_DIST_TEST_OPS ${DIST_TEST_OPS})
#remove distribute unittests.
list(APPEND MIXED_DIST_TEST_OPS test_dgc_op)
list(APPEND MIXED_DIST_TEST_OPS test_dgc_momentum_op)
list(APPEND MIXED_DIST_TEST_OPS test_dgc_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_simple_dist_transpiler)
list(APPEND MIXED_DIST_TEST_OPS test_listen_and_serv_op)
list(APPEND MIXED_DIST_TEST_OPS test_nce_remote_table_op)
......@@ -248,6 +250,8 @@ if(WITH_DISTRIBUTE)
py_test_modules(test_nce_remote_table_op MODULES test_nce_remote_table_op ENVS ${dist_ENVS})
if(WITH_DGC)
py_test_modules(test_dgc_op MODULES test_dgc_op)
py_test_modules(test_dgc_momentum_op MODULES test_dgc_momentum_op)
py_test_modules(test_dgc_optimizer MODULES test_dgc_optimizer)
endif()
if(NOT APPLE)
bash_test_modules(test_listen_and_serv_op MODULES test_listen_and_serv.sh)
......
......@@ -98,7 +98,7 @@ class TestDistMnist2x2(TestDistRunnerBase):
opt = fluid.optimizer.Momentum(learning_rate=self.lr, momentum=0.9)
else:
opt = fluid.optimizer.DGCMomentumOptimizer(
learning_rate=self.lr, momentum=0.9, rampup_begin_step=0)
learning_rate=self.lr, momentum=0.9, rampup_begin_step=2)
# Reader
train_reader = paddle.batch(
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
class TestDGCMomentumOp1(unittest.TestCase):
def get_tensor(self, name, value, place=None):
tensor = self.scope.var(name).get_tensor()
tensor.set(value, self.place if place is None else place)
return name, tensor
def setup(self, place, step=0.0):
self.scope = fluid.global_scope()
self.place = place
print("place:", place)
self.op_type = "dgc_momentum"
self.dtype = np.float32
param = np.random.random((123, 321)).astype(self.dtype)
grad = np.random.random((123, 321)).astype(self.dtype)
velocity = np.zeros((123, 321)).astype(self.dtype)
learning_rate = np.array([0.001]).astype(self.dtype)
current_step = np.full((1), step).astype("float32")
mu = 0.0001
use_nesterov = False
rampup_begin_step = 10.0
self.param_name, self.param_tensor = self.get_tensor('Param', param)
self.grad_name, self.grad_tensor = self.get_tensor('Grad', grad)
self.velocity_name, self.velocity_tensor = self.get_tensor('Velocity',
velocity)
self.learning_rate_name, self.learning_rate_tensor = self.get_tensor(
'LearningRate', learning_rate)
self.current_step_name, self.current_step_tensor = self.get_tensor(
'current_step', current_step, core.CPUPlace())
self.kwargs = {
# inputs
'Param': self.param_name,
'Grad': self.grad_name,
'Velocity': self.velocity_name,
'LearningRate': self.learning_rate_name,
'current_step': self.current_step_name,
# attrs
'mu': mu,
'use_nesterov': use_nesterov,
'rampup_begin_step': rampup_begin_step,
# outputs
'ParamOut': self.param_name,
'VelocityOut': self.velocity_name
}
velocity_out = mu * velocity + grad
if use_nesterov:
param_out = param - grad * learning_rate - \
velocity_out * mu * learning_rate
else:
param_out = param - learning_rate * velocity_out
sgd_out = param - learning_rate * grad
self.outputs = {
'ParamOut': param_out,
'VelocityOut': velocity_out,
'SGDOut': sgd_out
}
def check(self, actual_t, expect_t, place, out_name, atol=1e-5):
self.assertTrue(
np.allclose(
actual_t, expect_t, atol=atol),
"Output (" + out_name + ") has diff at " + str(place) + "\nExpect "
+ str(expect_t) + "\n" + "But Got" + str(actual_t))
def check_momentum_step(self, place):
self.setup(place=place)
dgc_momentum_op = Operator(self.op_type, **self.kwargs)
dgc_momentum_op.run(self.scope, self.place)
self.check(
np.array(self.param_tensor), self.outputs['ParamOut'], self.place,
self.param_name)
self.check(
np.array(self.velocity_tensor), self.outputs['VelocityOut'],
self.place, self.velocity_name)
def check_sgd_step(self, place):
self.setup(place=place, step=15.0)
dgc_momentum_op = Operator(self.op_type, **self.kwargs)
dgc_momentum_op.run(self.scope, self.place)
self.check(
np.array(self.param_tensor), self.outputs['SGDOut'], self.place,
self.param_name)
def test_cuda_place(self):
if not core.is_compiled_with_cuda():
return
place = core.CUDAPlace(0)
self.check_momentum_step(place)
self.check_sgd_step(place)
def test_cpu_place(self):
place = core.CPUPlace()
self.check_momentum_step(place)
self.check_sgd_step(place)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.fluid.framework as framework
import paddle.fluid.optimizer as optimizer
import paddle.compat as cpt
from paddle.fluid.backward import append_backward
from paddle.fluid.transpiler.details import program_to_code
class TestDGCMomentumOptimizer(unittest.TestCase):
class MockDGCMomentum(optimizer.DGCMomentumOptimizer):
def get_accumulators(self):
return self._accumulators
def get_velocity_str(self):
return self._velocity_acc_str
def check_dgc_momentum_optimizer(self, dims=[5, 10, 8], name="momentum"):
init_program = framework.Program()
program = framework.Program()
block = program.global_block()
mul_x = block.create_parameter(
dtype="float32",
shape=[dims[0], dims[1]],
lod_level=0,
name="mul.x",
optimize_attr={'learning_rate': 1.1})
mul_y = block.create_var(
dtype="float32",
shape=[dims[1], dims[2]],
lod_level=0,
name="mul.y")
mul_out = block.create_var(
dtype="float32",
shape=[dims[0], dims[2]],
lod_level=0,
name="mul.out")
block.append_op(
type="mul",
inputs={"X": mul_x,
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
learning_rate = 0.01
dgc_momentum_optimizer = self.MockDGCMomentum(
learning_rate=learning_rate, momentum=0.2, rampup_begin_step=0)
mean_out = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="mean.out")
block.append_op(
type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
# params_grads = append_backward(mean_out)
params_grads = dgc_momentum_optimizer.backward(mean_out)
self.assertEqual(len(params_grads), 1)
self.assertEqual(len(dgc_momentum_optimizer.get_accumulators()), 0)
with framework.program_guard(program, init_program):
opts = dgc_momentum_optimizer.apply_gradients(params_grads)
self.assertEqual(len(opts), 2)
sgd_op = opts[-1]
self.assertEqual([op.type for op in opts], ["scale", name])
self.assertFalse(sgd_op.attr('use_nesterov'))
# Check accumulators
accumulators = dgc_momentum_optimizer.get_accumulators()
self.assertEqual(len(accumulators), 1)
self.assertTrue(
dgc_momentum_optimizer.get_velocity_str() in accumulators)
velocity_acc = accumulators[dgc_momentum_optimizer.get_velocity_str()]
self.assertEqual(len(velocity_acc), 1)
self.assertTrue(mul_x.name in velocity_acc)
# Check init_program
init_ops = init_program.global_block().ops
self.assertEqual(len(init_ops), 2)
self.assertEqual(init_ops[0].type, "fill_constant")
self.assertAlmostEqual(init_ops[0].attr('value'), learning_rate)
self.assertEqual(init_ops[1].type, "fill_constant")
self.assertAlmostEqual(init_ops[1].attr('value'), 0.0)
with open("test_dgc_optimizer_" + name + ".log", "w") as f:
program_to_code(program, fout=f)
def test_momentum_without_dgc(self):
self.check_dgc_momentum_optimizer()
def test_momentum_with_dgc(self):
# 16 * 1024 = 16384, use dgc momentum
self.check_dgc_momentum_optimizer(
dims=[16, 1024, 8], name="dgc_momentum")
if __name__ == '__main__':
unittest.main()
......@@ -17,9 +17,20 @@ import unittest
from test_dist_base import TestDistBase
import os
import subprocess
flag_name = os.path.splitext(__file__)[0]
def count_of_sparse_all_reduce_calls(file_name):
cmd = 'grep sparse_all_reduce_op_handle ' + file_name + ' | grep in_numel | wc -l'
child = subprocess.Popen(cmd, stdout=subprocess.PIPE, shell=True)
result = child.communicate()[0]
print('test_info: result = ' + str(result))
# note. in python3, result is b'num', != 'num'
return int(result)
class TestDistMnistNCCL2DGC(TestDistBase):
def _setup_config(self):
self._sync_mode = True
......@@ -37,6 +48,15 @@ class TestDistMnistNCCL2DGC(TestDistBase):
check_error_log=True,
log_name=flag_name)
def tearDown(self):
result = count_of_sparse_all_reduce_calls(
'test_dist_mnist_dgc_nccl_tr0_err.log')
# only 1 layer use dgc now, run_step=5, rampup_begin_step=2, so 1 * (5 - 2) = 3
# temp close this test. In python3 CI, the log is right, but the result
# has a problem, may be in multi process mode, log is not writed in time.
# self.assertEqual(result, 3)
class TestDistMnistNCCL2DGCMultiCards(TestDistBase):
def _setup_config(self):
......@@ -55,6 +75,12 @@ class TestDistMnistNCCL2DGCMultiCards(TestDistBase):
check_error_log=True,
log_name=flag_name)
def tearDown(self):
result = count_of_sparse_all_reduce_calls(
'test_dist_mnist_dgc_nccl_dgc_2cards_local.log')
# same as above, but use two cards
self.assertEqual(result, 6)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册