未验证 提交 bf6ec262 编写于 作者: C Chenxiao Niu 提交者: GitHub

[MLU] add mergedAdam kernel. (#45965)

上级 da33f7b0
......@@ -333,6 +333,224 @@ class AdamWMLUKernel : public AdamMLUKernel<T> {
}
};
template <typename T>
class MergedAdamMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// Get inputs and outputs
auto params = ctx.MultiInput<framework::Tensor>("Param");
auto grads = ctx.MultiInput<framework::Tensor>("Grad");
auto lrs = ctx.MultiInput<framework::Tensor>("LearningRate");
auto mom1s = ctx.MultiInput<framework::Tensor>("Moment1");
auto mom2s = ctx.MultiInput<framework::Tensor>("Moment2");
auto beta1_pows = ctx.MultiInput<framework::Tensor>("Beta1Pow");
auto beta2_pows = ctx.MultiInput<framework::Tensor>("Beta2Pow");
auto master_params = ctx.MultiInput<framework::Tensor>("MasterParam");
auto param_outs = ctx.MultiOutput<framework::Tensor>("ParamOut");
auto mom1_outs = ctx.MultiOutput<framework::Tensor>("Moment1Out");
auto mom2_outs = ctx.MultiOutput<framework::Tensor>("Moment2Out");
auto beta1_pow_outs = ctx.MultiOutput<framework::Tensor>("Beta1PowOut");
auto beta2_pow_outs = ctx.MultiOutput<framework::Tensor>("Beta2PowOut");
// Check validation of inputs and outputs
size_t param_num = params.size();
PADDLE_ENFORCE_EQ(param_num,
param_outs.size(),
platform::errors::InvalidArgument(
"The size of Output(ParamOut) must be equal to "
"Input(Param), but got the size of Output(ParamOut) "
"is %d, the size of Input(Param) is %d.",
param_outs.size(),
param_num));
bool skip_update = false;
if (ctx.HasInput("SkipUpdate")) {
auto* skip_update_tensor = ctx.Input<framework::Tensor>("SkipUpdate");
PADDLE_ENFORCE_EQ(skip_update_tensor->numel(),
1,
platform::errors::InvalidArgument(
"Input(SkipUpdate) size must be 1, but get %d",
skip_update_tensor->numel()));
std::vector<bool> skip_update_vec;
paddle::framework::TensorToVector(
*skip_update_tensor, ctx.device_context(), &skip_update_vec);
ctx.device_context().Wait();
skip_update = skip_update_vec[0];
}
// skip_update=true, just copy input to output, and TensorCopy will call
// mutable_data
if (skip_update) {
VLOG(4) << "MergedAdam skip update";
for (size_t i = 0; i < param_num; ++i) {
framework::TensorCopy(
*params[i],
ctx.GetPlace(),
ctx.template device_context<platform::MLUDeviceContext>(),
param_outs[i]);
framework::TensorCopy(
*mom1s[i],
ctx.GetPlace(),
ctx.template device_context<platform::MLUDeviceContext>(),
mom1_outs[i]);
framework::TensorCopy(
*mom2s[i],
ctx.GetPlace(),
ctx.template device_context<platform::MLUDeviceContext>(),
mom2_outs[i]);
framework::TensorCopy(
*beta1_pows[i],
beta1_pows[i]->place(),
ctx.template device_context<platform::MLUDeviceContext>(),
beta1_pow_outs[i]);
framework::TensorCopy(
*beta2_pows[i],
beta2_pows[i]->place(),
ctx.template device_context<platform::MLUDeviceContext>(),
beta2_pow_outs[i]);
}
return;
}
bool use_global_beta_pow = ctx.Attr<bool>("use_global_beta_pow");
VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;
// Get beta1, beta2 and epsilon from attribute.
const Tensor* beta1_tensor = nullptr;
const Tensor* beta2_tensor = nullptr;
const Tensor* epsilon_tensor = nullptr;
Tensor beta1_tmp(experimental::DataType::FLOAT32);
Tensor beta2_tmp(experimental::DataType::FLOAT32);
Tensor epsilon_tmp(experimental::DataType::FLOAT32);
T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
beta1_tmp.mutable_data<T>({1}, ctx.GetPlace());
beta2_tmp.mutable_data<T>({1}, ctx.GetPlace());
epsilon_tmp.mutable_data<T>({1}, ctx.GetPlace());
MLUCnnlTensorDesc beta1_tmp_desc(beta1_tmp);
MLUCnnlTensorDesc beta2_tmp_desc(beta2_tmp);
MLUCnnlTensorDesc epsilon_tmp_desc(epsilon_tmp);
MLUCnnl::Fill(ctx,
CNNL_POINTER_MODE_HOST,
&beta1,
beta1_tmp_desc.get(),
GetBasePtr(&beta1_tmp));
MLUCnnl::Fill(ctx,
CNNL_POINTER_MODE_HOST,
&beta2,
beta2_tmp_desc.get(),
GetBasePtr(&beta2_tmp));
MLUCnnl::Fill(ctx,
CNNL_POINTER_MODE_HOST,
&epsilon,
epsilon_tmp_desc.get(),
GetBasePtr(&epsilon_tmp));
beta1_tensor = &beta1_tmp;
beta2_tensor = &beta2_tmp;
epsilon_tensor = &epsilon_tmp;
// Loop to compute
for (size_t i = 0; i < param_num; ++i) {
VLOG(4) << "[MergedAdam] loop: " << i;
param_outs[i]->ShareDataWith(*params[i]);
mom1_outs[i]->ShareDataWith(*mom1s[i]);
mom2_outs[i]->ShareDataWith(*mom2s[i]);
LoDTensor beta1_pow_tmp;
LoDTensor beta2_pow_tmp;
if (beta1_pows[i]->place() == platform::CPUPlace()) {
T beta1 = *beta1_pows[i]->data<T>();
beta1_pow_tmp.mutable_data<T>({1}, ctx.GetPlace());
MLUCnnlTensorDesc beta1_pow_tmp_desc(beta1_pow_tmp);
MLUCnnl::Fill(ctx,
CNNL_POINTER_MODE_HOST,
&beta1,
beta1_pow_tmp_desc.get(),
GetBasePtr(&beta1_pow_tmp));
beta1_pows[i] = &beta1_pow_tmp;
}
if (beta2_pows[i]->place() == platform::CPUPlace()) {
T beta2 = *beta2_pows[i]->data<T>();
beta2_pow_tmp.mutable_data<T>({1}, ctx.GetPlace());
MLUCnnlTensorDesc beta2_pow_tmp_desc(beta2_pow_tmp);
MLUCnnl::Fill(ctx,
CNNL_POINTER_MODE_HOST,
&beta2,
beta2_pow_tmp_desc.get(),
GetBasePtr(&beta2_pow_tmp));
beta2_pows[i] = &beta2_pow_tmp;
}
VLOG(3) << "beta1_pow.numel() : " << beta1_pows[i]->numel()
<< "beta2_pow.numel() : " << beta2_pows[i]->numel();
VLOG(3) << "param.numel(): " << params[i]->numel();
PADDLE_ENFORCE_EQ(beta1_pow_outs[i]->numel(),
1,
platform::errors::InvalidArgument(
"beta1 pow output size should be 1, but received "
"value is:%d.",
beta1_pow_outs[i]->numel()));
PADDLE_ENFORCE_EQ(beta2_pow_outs[i]->numel(),
1,
platform::errors::InvalidArgument(
"beta2 pow output size should be 1, but received "
"value is:%d.",
beta2_pow_outs[i]->numel()));
MLUCnnlTensorDesc param_desc(*params[i]);
MLUCnnlTensorDesc mom1_desc(*mom1s[i]);
MLUCnnlTensorDesc mom2_desc(*mom2s[i]);
MLUCnnlTensorDesc grad_desc(*grads[i]);
MLUCnnl::ApplyAdam(ctx,
param_desc.get(),
GetBasePtr(param_outs[i]),
mom1_desc.get(),
GetBasePtr(mom1_outs[i]),
mom2_desc.get(),
GetBasePtr(mom2_outs[i]),
grad_desc.get(),
GetBasePtr(grads[i]),
GetBasePtr(lrs[i]),
GetBasePtr(beta1_tensor),
GetBasePtr(beta2_tensor),
GetBasePtr(beta1_pows[i]),
GetBasePtr(beta2_pows[i]),
GetBasePtr(epsilon_tensor),
/*use_nesterov*/ false);
if (!use_global_beta_pow) {
beta1_pow_outs[i]->mutable_data<T>(ctx.GetPlace());
beta2_pow_outs[i]->mutable_data<T>(ctx.GetPlace());
MLUCnnlTensorDesc beta1_desc(*beta1_tensor);
MLUCnnlOpTensorDesc mul_op_desc(
CNNL_OP_TENSOR_MUL, ToCnnlDataType<T>(), CNNL_NOT_PROPAGATE_NAN);
MLUCnnl::OpTensor(ctx,
mul_op_desc.get(),
beta1_desc.get(),
GetBasePtr(beta1_pows[i]),
beta1_desc.get(),
GetBasePtr(beta1_tensor),
beta1_desc.get(),
GetBasePtr(beta1_pow_outs[i]),
ToCnnlDataType<T>());
MLUCnnl::OpTensor(ctx,
mul_op_desc.get(),
beta1_desc.get(),
GetBasePtr(beta2_pows[i]),
beta1_desc.get(),
GetBasePtr(beta2_tensor),
beta1_desc.get(),
GetBasePtr(beta2_pow_outs[i]),
ToCnnlDataType<T>());
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -346,3 +564,7 @@ REGISTER_OP_MLU_KERNEL(adam,
REGISTER_OP_MLU_KERNEL(adamw,
ops::AdamWMLUKernel<float>,
ops::AdamWMLUKernel<plat::float16>);
REGISTER_OP_MLU_KERNEL(merged_adam,
ops::MergedAdamMLUKernel<float>,
ops::MergedAdamMLUKernel<plat::float16>);
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append('..')
import unittest
import paddle
import numpy as np
from paddle import _C_ops, _legacy_C_ops
from paddle.fluid.framework import in_dygraph_mode
def run_adam_op(params,
grads,
lrs,
moment1s,
moment2s,
beta1_pows,
beta2_pows,
master_params,
epsilon,
beta1,
beta2,
place,
multi_precision=False,
use_merged=False):
assert len(params) == len(grads)
assert len(params) == len(lrs)
assert len(params) == len(moment1s)
assert len(params) == len(moment2s)
assert len(params) == len(beta1_pows)
assert len(params) == len(beta1_pows)
assert len(params) == len(master_params)
paddle.disable_static()
# paddle.set_device(place)
param_vars = [paddle.fluid.dygraph.to_variable(p) for p in params]
grad_vars = [paddle.fluid.dygraph.to_variable(g) for g in grads]
lr_vars = [paddle.fluid.dygraph.to_variable(l) for l in lrs]
moment1_vars = [paddle.fluid.dygraph.to_variable(m) for m in moment1s]
moment2_vars = [paddle.fluid.dygraph.to_variable(m) for m in moment2s]
beta1_pow_vars = [paddle.fluid.dygraph.to_variable(b) for b in beta1_pows]
beta2_pow_vars = [paddle.fluid.dygraph.to_variable(b) for b in beta2_pows]
master_param_vars = [
paddle.fluid.dygraph.to_variable(m_p) for m_p in master_params
]
if not use_merged:
for i in range(len(param_vars)):
_, _, _, _, _, _ = _legacy_C_ops.adam(
param_vars[i], grad_vars[i], lr_vars[i], moment1_vars[i],
moment2_vars[i], beta1_pow_vars[i], beta2_pow_vars[i],
master_param_vars[i], param_vars[i], moment1_vars[i],
moment2_vars[i], beta1_pow_vars[i], beta2_pow_vars[i],
master_param_vars[i], 'epsilon', epsilon, 'beta1', beta1,
'beta2', beta2, 'multi_precision', multi_precision)
else:
if in_dygraph_mode():
_, _, _, _, _, _ = _C_ops.merged_adam_(
param_vars, grad_vars, lr_vars, moment1_vars, moment2_vars,
beta1_pow_vars, beta2_pow_vars, master_param_vars, beta1, beta2,
epsilon, multi_precision, False)
else:
_, _, _, _, _, _ = _legacy_C_ops.merged_adam(
param_vars, grad_vars, lr_vars, moment1_vars, moment2_vars,
beta1_pow_vars, beta2_pow_vars, master_param_vars, param_vars,
moment1_vars, moment2_vars, beta1_pow_vars, beta2_pow_vars,
master_param_vars, 'epsilon', epsilon, 'beta1', beta1, 'beta2',
beta2, 'multi_precision', multi_precision)
outputs = {
'ParamOut': param_vars,
'Moment1Out': moment1_vars,
'Moment2Out': moment2_vars,
'Beta1PowOut': beta1_pow_vars,
'Beta2PowOut': beta2_pow_vars,
'MasterParamOut': master_param_vars
}
return outputs
class TestMergedAdam(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.shapes = [[3, 4], [2, 7], [5, 6], [7, 8]]
self.seed = 10
self.place = paddle.device.MLUPlace(0)
self.__class__.use_mlu = True
def gen_rand_data(self, shapes, dtype):
return [np.random.random(s).astype(dtype) for s in shapes]
def prepare_data(self, shapes, multi_precision, seed, place):
np.random.seed(seed)
mp_dtype = np.float32
# dtype = np.float16 if multi_precision and place == 'mlu' else np.float32
dtype = np.float32
params = self.gen_rand_data(shapes, dtype)
grads = self.gen_rand_data(shapes, dtype)
lrs = self.gen_rand_data([[1], [1], [1], [1]], mp_dtype)
moment1s = self.gen_rand_data(shapes, mp_dtype)
moment2s = self.gen_rand_data(shapes, mp_dtype)
beta1_pows = self.gen_rand_data([[1], [1], [1], [1]], mp_dtype)
beta2_pows = self.gen_rand_data([[1], [1], [1], [1]], mp_dtype)
master_params = [p.astype(mp_dtype) for p in params]
return params, grads, lrs, moment1s, moment2s, beta1_pows, beta2_pows, master_params
def check_with_place(self, place, multi_precision):
params, grads, lrs, moment1s, moment2s, beta1_pows, beta2_pows, master_params = self.prepare_data(
self.shapes, multi_precision, self.seed, place)
def run_op(use_merged):
return run_adam_op(params=params,
grads=grads,
lrs=lrs,
moment1s=moment1s,
moment2s=moment2s,
beta1_pows=beta1_pows,
beta2_pows=beta2_pows,
master_params=master_params,
epsilon=0.9,
beta1=0.9,
beta2=0.99,
place=place,
multi_precision=multi_precision,
use_merged=use_merged)
outs1 = run_op(True)
outs2 = run_op(False)
self.assertEqual(len(outs1), len(outs2))
for key in outs1.keys():
value1 = outs1[key]
value2 = outs2[key]
for i in range(len(value1)):
if place == 'mlu':
np.testing.assert_array_equal(value1[i], value2[i])
else:
np.testing.assert_allclose(value1[i],
value2[i],
rtol=1e-05,
atol=1e-07)
def test_main(self):
for multi_precision in [False, True]:
self.check_with_place(self.place, multi_precision)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册