Unverified commit f4eda869 authored by Zeng Jinle, committed by GitHub

Merge momentum ops/kernels (#36380)

* merge momentum ops

* update

* add ut to improve coverage

* remove optimizer change

* fix error msg

* update ut

* add __restrict__ for CUDA

* update ut

* move merged_momentum_op to optimizer dir

* fix coverage
Parent eb722e34
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"
namespace paddle {
namespace operators {
class MergedMomentumOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
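  // The outputs reuse the memory of the corresponding inputs (see the OpMaker
  // below), so there is nothing to infer here and InferShape is intentionally
  // left empty.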
void InferShape(framework::InferShapeContext *ctx) const override {}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto param_dtype =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "Param");
return framework::OpKernelType(param_dtype, ctx.GetPlace());
}
};
class MergedMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Param",
"(Tensor, default Tensor<float>) "
"Input parameter that has to be updated")
.AsDuplicable();
AddInput("Grad",
"(Tensor, default Tensor<float>) "
"Input gradient of the parameter")
.AsDuplicable();
AddInput("Velocity",
"(Tensor, default Tensor<float>) "
"Input velocity (corresponding to the parameter) "
"that has to be updated")
.AsDuplicable();
AddInput("LearningRate",
"(Tensor, default Tensor<float>) "
"Input learning rate");
AddInput("MasterParam", "FP32 master weight for AMP.")
.AsDispensable()
.AsDuplicable();
AddOutput("ParamOut",
"(Tensor) This output is updated parameter. "
"It shared memory with Input(Param).")
.AsDuplicable();
AddOutput("VelocityOut",
"(Tensor) This output is updated velocity. "
"It shared memory with Input(Velocity).")
.AsDuplicable();
AddOutput("MasterParamOut",
"The updated FP32 master weight for AMP. "
"It shared memory with Input(MasterParam).")
.AsDispensable()
.AsDuplicable();
AddAttr<float>("mu", "(float) Momentum coefficient");
AddAttr<bool>("multi_precision",
"(bool, default false) "
"Whether to use multi-precision during weight updating.")
.SetDefault(false);
AddAttr<float>(
"rescale_grad",
"(float, default 1.0) Multiply the gradient with `rescale_grad`"
"before updating. Often choose to be `1.0/batch_size`.")
.SetDefault(1.0f);
AddComment(R"DOC(Merged Momentum Optimizer.)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(merged_momentum, ops::MergedMomentumOp,
ops::MergedMomentumOpMaker);
REGISTER_OP_CPU_KERNEL(
merged_momentum, ops::MergedMomentumOpKernel<plat::CPUDeviceContext, float>,
ops::MergedMomentumOpKernel<plat::CPUDeviceContext, double>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
merged_momentum,
ops::MergedMomentumOpKernel<plat::CUDADeviceContext, plat::float16>,
ops::MergedMomentumOpKernel<plat::CUDADeviceContext, float>,
ops::MergedMomentumOpKernel<plat::CUDADeviceContext, double>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace operators {
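// Holds the optional FP32 master-parameter pointers used for AMP. The primary
// template stores one pointer per merged parameter; the specialization below
// (kHasMasterParams == false) stores nothing, keeping the kernel-argument
// struct small when master weights are not used.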
template <typename MT, uint32_t kParamNum, bool kHasMasterParams>
struct MergedMomentumMasterParams {
MT *PADDLE_RESTRICT master_params[kParamNum];
HOSTDEVICE MT *MasterParam(size_t idx) const { return master_params[idx]; }
HOSTDEVICE void SetMasterParam(size_t idx, MT *p) { master_params[idx] = p; }
};
template <typename MT, uint32_t kParamNum>
struct MergedMomentumMasterParams<MT, kParamNum, false> {
HOSTDEVICE constexpr MT *MasterParam(size_t) const { return nullptr; }
HOSTDEVICE constexpr void SetMasterParam(size_t, MT *) {}
};
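// Packs all arguments of one merged-momentum launch by value so the struct can
// be passed directly to the device kernel. The cap on kParamNum (55 with
// master params, 110 without) is presumably chosen so that the struct fits
// within the CUDA kernel argument size limit.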
template <typename T, typename MT, bool kHasMasterParams,
uint32_t kParamNum = kHasMasterParams ? 55 : 110>
struct MergedMomentumKernelParam
: public MergedMomentumMasterParams<MT, kParamNum, kHasMasterParams> {
static constexpr auto N = kParamNum;
size_t sizes[N];
T *PADDLE_RESTRICT params[N];
const T *PADDLE_RESTRICT grads[N];
MT *PADDLE_RESTRICT velocitys[N];
const MT *PADDLE_RESTRICT lr;
MT mu;
MT rescale_grad;
uint32_t param_num;
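  // Element-wise momentum update: for element index i, visit every merged
  // tensor large enough to contain i and apply
  //   velocity_out = mu * velocity + rescale_grad * grad
  //   param_out    = param - lr * velocity_out
  // writing the result back to the FP32 master parameter when one is present.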
HOSTDEVICE void operator()(size_t i) const {
const auto lr_val = *lr;
for (uint32_t idx = 0; idx < param_num; ++idx) {
auto size = sizes[idx];
if (i >= size) continue;
auto param_p = params[idx];
auto grad_p = grads[idx];
auto velocity_p = velocitys[idx];
auto master_param_p = this->MasterParam(idx);
const MT param =
master_param_p ? master_param_p[i] : static_cast<MT>(param_p[i]);
const MT grad = static_cast<MT>(grad_p[i]) * rescale_grad;
const MT velocity = velocity_p[i];
const MT velocity_out = velocity * mu + grad;
const MT param_out = param - lr_val * velocity_out;
velocity_p[i] = velocity_out;
param_p[i] = static_cast<T>(param_out);
if (master_param_p) {
master_param_p[i] = param_out;
}
}
}
};
template <typename DeviceContext, typename T>
class MergedMomentumOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto params = ctx.MultiInput<framework::Tensor>("Param");
auto params_out = ctx.MultiOutput<framework::Tensor>("ParamOut");
size_t n = params.size();
PADDLE_ENFORCE_EQ(
n, params_out.size(),
platform::errors::InvalidArgument(
"Output(ParamOut) number must be equal to Input(Param) number."));
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_EQ(
params[i], params_out[i],
platform::errors::InvalidArgument(
"Input(Param) and Output(ParamOut) must be the same Tensors."));
}
auto grads = ctx.MultiInput<framework::Tensor>("Grad");
PADDLE_ENFORCE_EQ(
n, grads.size(),
platform::errors::InvalidArgument(
"Input(Grad) number must be equal to Input(Param) number."));
auto velocitys = ctx.MultiInput<framework::Tensor>("Velocity");
PADDLE_ENFORCE_EQ(n, velocitys.size(),
                  platform::errors::InvalidArgument(
                      "Input(Velocity) number must be equal to "
                      "Input(Param) number."));
auto velocitys_out = ctx.MultiOutput<framework::Tensor>("VelocityOut");
PADDLE_ENFORCE_EQ(
n, velocitys_out.size(),
platform::errors::InvalidArgument("Output(VelocityOut) number must be "
"equal to Input(Param) number."));
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_EQ(velocitys[i], velocitys_out[i],
platform::errors::InvalidArgument(
"Input(Velocity) and Output(VelocityOut) must be "
"the same Tensors."));
}
auto master_params = ctx.MultiInput<framework::Tensor>("MasterParam");
auto master_params_out =
ctx.MultiOutput<framework::Tensor>("MasterParamOut");
auto multi_precision = ctx.Attr<bool>("multi_precision");
if (multi_precision) {
PADDLE_ENFORCE_EQ(
n, master_params.size(),
platform::errors::InvalidArgument("Input(MasterParam) number must be "
"equal to Input(Param) number."));
PADDLE_ENFORCE_EQ(n, master_params_out.size(),
platform::errors::InvalidArgument(
"Output(MasterParamOut) number must be equal to "
"Input(MasterParam) number."));
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_EQ(master_params[i], master_params_out[i],
platform::errors::InvalidArgument(
"Input(MasterParam) and Output(MasterParamOut) "
"must be the same Tensors."));
PADDLE_ENFORCE_NOT_NULL(master_params[i],
platform::errors::InvalidArgument(
"Input(MasterParam) must be provided when "
"multi_precision=True."));
}
} else {
master_params.clear();
master_params_out.clear();
}
auto lr = ctx.Input<framework::Tensor>("LearningRate");
auto mu = ctx.Attr<float>("mu");
auto rescale_grad = ctx.Attr<float>("rescale_grad");
using MPType = typename operators::details::MPTypeTrait<T>::Type;
auto &dev_ctx = ctx.template device_context<DeviceContext>();
#define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(kMultiPrecision) \
MergedMomentumKernelParam<T, MPType, kMultiPrecision> kernel_params; \
constexpr auto kMaxMergedNum = decltype(kernel_params)::N; \
size_t kernel_num = (n + kMaxMergedNum - 1) / kMaxMergedNum; \
kernel_params.mu = static_cast<MPType>(mu); \
kernel_params.rescale_grad = static_cast<MPType>(rescale_grad); \
kernel_params.lr = lr->data<MPType>(); \
for (size_t i = 0; i < kernel_num; ++i) { \
size_t start = i * kMaxMergedNum; \
size_t end = std::min((i + 1) * kMaxMergedNum, n); \
kernel_params.param_num = static_cast<uint32_t>(end - start); \
size_t max_size = 0; \
for (size_t j = 0; j < kernel_params.param_num; ++j) { \
auto size = static_cast<size_t>(params_out[j + start]->numel()); \
max_size = std::max(max_size, size); \
kernel_params.sizes[j] = size; \
kernel_params.params[j] = params_out[j + start]->data<T>(); \
kernel_params.grads[j] = grads[j + start]->data<T>(); \
kernel_params.velocitys[j] = velocitys_out[j + start]->data<MPType>(); \
kernel_params.SetMasterParam( \
j, kMultiPrecision ? master_params_out[j + start]->data<MPType>() \
: nullptr); \
} \
platform::ForRange<DeviceContext> for_range(dev_ctx, max_size); \
for_range(kernel_params); \
VLOG(10) << "Launch MergedMomentum kernel " << i << " " \
<< kernel_params.param_num; \
}
if (multi_precision) {
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(true);
} else {
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(false);
}
#undef PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL
}
};
} // namespace operators
} // namespace paddle
@@ -30,3 +30,9 @@ limitations under the License. */
#define FLT_MAX __FLT_MAX__
#endif // __FLT_MAX__
#endif // PADDLE_WITH_MUSL
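// PADDLE_RESTRICT expands to __restrict__ when compiling with NVCC or HIPCC,
// hinting to the compiler that the annotated pointers do not alias; on other
// compilers it expands to nothing.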
#if defined(__NVCC__) || defined(__HIPCC__)
#define PADDLE_RESTRICT __restrict__
#else
#define PADDLE_RESTRICT
#endif
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import numpy as np
from paddle.fluid.layer_helper import LayerHelper
from collections import OrderedDict
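# Build and run a static program that updates `params` either with one
# `momentum` op per parameter or with a single `merged_momentum` op over all of
# them, then return the fetched (in-place updated) parameters and velocities
# (plus master parameters on CUDAPlace when multi_precision is enabled) so the
# two code paths can be compared.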
def run_momentum_op(params,
grads,
velocitys,
master_params,
learning_rate,
place,
multi_precision,
mu=0.9,
rescale_grad=0.01,
use_merged=False):
assert len(params) == len(grads)
assert len(params) == len(velocitys)
if multi_precision:
assert len(params) == len(master_params)
op_type = 'merged_momentum' if use_merged else 'momentum'
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
helper = LayerHelper(op_type, **locals())
attrs = {
'mu': mu,
'multi_precision': multi_precision,
'rescale_grad': rescale_grad,
}
param_vars = [
helper.create_variable(
persistable=True, shape=p.shape, dtype=p.dtype) for p in params
]
grad_vars = [
helper.create_variable(
shape=g.shape, dtype=g.dtype) for g in grads
]
velocity_vars = [
helper.create_variable(
persistable=True, shape=v.shape, dtype=v.dtype)
for v in velocitys
]
lr_var = helper.create_variable(
persistable=True,
shape=learning_rate.shape,
dtype=learning_rate.dtype)
feed_dict = OrderedDict()
feed_dict.update(
OrderedDict([(p_var.name, p_val)
for p_var, p_val in zip(param_vars, params)]))
feed_dict.update(
OrderedDict([(v_var.name, v_val)
for v_var, v_val in zip(velocity_vars, velocitys)]))
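        # Fetch the params and velocities (updated in place by the op); grads
        # and the learning rate are added to feed_dict afterwards so they are
        # not fetched.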
fetch_list = list(feed_dict.keys())
feed_dict.update(
OrderedDict([(g_var.name, g_val)
for g_var, g_val in zip(grad_vars, grads)]))
feed_dict.update({lr_var.name: learning_rate})
if multi_precision:
master_param_vars = [
helper.create_variable(
persistable=True, shape=p.shape, dtype=p.dtype)
for p in master_params
]
feed_dict.update(
OrderedDict([(mp_var.name, mp_val)
for mp_var, mp_val in zip(master_param_vars,
master_params)]))
# CPUPlace does not use MasterParam
if isinstance(place, paddle.CUDAPlace):
fetch_list = fetch_list + [
mp_var.name for mp_var in master_param_vars
]
else:
master_param_vars = None
if not use_merged:
for i, (p, g,
v) in enumerate(zip(param_vars, grad_vars, velocity_vars)):
inputs = {
'Param': p,
'Grad': g,
'Velocity': v,
'LearningRate': lr_var
}
outputs = {'ParamOut': p, 'VelocityOut': v}
if multi_precision:
inputs['MasterParam'] = master_param_vars[i]
outputs['MasterParamOut'] = master_param_vars[i]
helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)
else:
inputs = {
'Param': param_vars,
'Grad': grad_vars,
'Velocity': velocity_vars,
'LearningRate': lr_var
}
outputs = {'ParamOut': param_vars, 'VelocityOut': velocity_vars}
if multi_precision:
inputs['MasterParam'] = master_param_vars
outputs['MasterParamOut'] = master_param_vars
helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)
exe = paddle.static.Executor(place)
with paddle.static.scope_guard(paddle.static.Scope()):
exe.run(startup)
return exe.run(main, feed=feed_dict, fetch_list=fetch_list)
class TestMergedMomentum(unittest.TestCase):
def setUp(self):
paddle.enable_static()
self.shapes = [[3, 4], [2, 7], [5, 6], [7, 8]]
self.seed = 10
def gen_rand_data(self, shapes, dtype):
return [np.random.random(s).astype(dtype) for s in shapes]
def prepare_data(self, shapes, multi_precision, seed, place):
np.random.seed(seed)
mp_dtype = np.float32
dtype = np.float16 if multi_precision and isinstance(
place, paddle.CUDAPlace) else np.float32
params = self.gen_rand_data(shapes, dtype)
grads = self.gen_rand_data(shapes, dtype)
velocitys = self.gen_rand_data(shapes, mp_dtype)
learning_rate = self.gen_rand_data([[1]], mp_dtype)[0]
if multi_precision:
master_params = [p.astype(mp_dtype) for p in params]
else:
master_params = None
return params, grads, velocitys, master_params, learning_rate
def check_with_place(self, place, multi_precision):
params, grads, velocitys, master_params, learning_rate = self.prepare_data(
self.shapes, multi_precision, self.seed, place)
def run_op(use_merged):
# FIXME(zengjinle): CPU Momentum Op does not support rescale_grad
rescale_grad = 1.0 if isinstance(place, paddle.CPUPlace) else 0.01
return run_momentum_op(
params,
grads,
velocitys,
master_params,
learning_rate,
place,
multi_precision,
rescale_grad=rescale_grad,
use_merged=use_merged)
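        # The merged op must produce the same updates as running the plain
        # momentum op once per parameter.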
outs1 = run_op(True)
outs2 = run_op(False)
self.assertEqual(len(outs1), len(outs2))
for i, (out1, out2) in enumerate(zip(outs1, outs2)):
self.assertTrue(np.allclose(out1, out2, atol=1e-7))
def get_places(self):
places = [paddle.CPUPlace()]
if paddle.is_compiled_with_cuda():
places.append(paddle.CUDAPlace(0))
return places
def test_main(self):
for multi_precision in [False, True]:
for place in self.get_places():
self.check_with_place(place, multi_precision)
if __name__ == "__main__":
unittest.main()