Unverified commit 9e3e08f0, authored by ronnywang, committed by GitHub

[NPU] add momentum_op_npu and test (#34082)

* add momentum_op_npu and test

* update

* fix hang
Parent f6fab559
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/optimizers/momentum_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"
#include "paddle/fluid/operators/optimizers/sgd_op.h"
namespace paddle {
namespace operators {
template <typename T>
class NPUMomentumOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.template device_context<platform::NPUDeviceContext>();
std::string regularization_method =
ctx.Attr<std::string>("regularization_method");
auto regularization_coeff = ctx.Attr<float>("regularization_coeff");
RegularizationType regularization_flag{
RegularizationType::kNONE}; // disable regularization
if (regularization_method == "l2_decay") {
regularization_flag = RegularizationType::kL2DECAY;
}
T mu = static_cast<T>(ctx.Attr<float>("mu"));
bool use_nesterov = ctx.Attr<bool>("use_nesterov");
auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
auto param = ctx.Input<framework::Tensor>("Param");
auto velocity = ctx.Input<framework::Tensor>("Velocity");
auto param_out = ctx.Output<framework::Tensor>("ParamOut");
auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
param_out->mutable_data<T>(ctx.GetPlace());
velocity_out->mutable_data<T>(ctx.GetPlace());
auto* grad_var = ctx.InputVar("Grad");
if (grad_var->IsType<framework::LoDTensor>()) {
auto grad = ctx.Input<framework::Tensor>("Grad");
Tensor mu_tensor;
mu_tensor.mutable_data<T>(framework::make_ddim({1}), ctx.GetPlace());
FillNpuTensorWithConstant<T>(&mu_tensor, mu);
Tensor regularized_grad;
if (regularization_flag == RegularizationType::kL2DECAY) {
regularized_grad.mutable_data<T>(grad->dims(), ctx.GetPlace());
const auto& runner1 = NpuOpRunner("Muls", {*param}, {regularized_grad},
{{"value", regularization_coeff}});
runner1.Run(dev_ctx.stream());
const auto& runner2 = NpuOpRunner("Add", {regularized_grad, *grad},
{regularized_grad}, {});
runner2.Run(dev_ctx.stream());
} else {
regularized_grad.ShareDataWith(*grad);
}
framework::TensorCopy(*param, ctx.GetPlace(), dev_ctx, param_out);
framework::TensorCopy(*velocity, ctx.GetPlace(), dev_ctx, velocity_out);
// NOTE: ApplyMomentum will change the input
const auto& runner = NpuOpRunner(
"ApplyMomentum", {*param_out, *velocity_out, *learning_rate,
regularized_grad, mu_tensor},
{*param_out}, {{"use_nesterov", use_nesterov}});
runner.Run(dev_ctx.stream());
} else if (grad_var->IsType<framework::SelectedRows>()) {
      PADDLE_ENFORCE_EQ(false, true, platform::errors::PermissionDenied(
                                         "Unsupported SparseMomentum"));
} else {
      PADDLE_ENFORCE_EQ(false, true,
                        platform::errors::PermissionDenied(
                            "Unsupported Variable Type of Grad "
                            "in MomentumOp. Expected LoDTensor "
                            "or SelectedRows, but received [%s]",
                            paddle::framework::ToTypeName(grad_var->Type())));
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(momentum, ops::NPUMomentumOpKernel<float>,
ops::NPUMomentumOpKernel<plat::float16>);
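For context, the kernel above builds the L2-decayed gradient with a Muls/Add pair when regularization_method is "l2_decay", copies Param and Velocity into the output tensors, and then hands everything to CANN's ApplyMomentum, which updates ParamOut and VelocityOut in place. A minimal NumPy sketch of that update, mirroring the calculate_momentum_by_numpy helper used by the test below (function and variable names here are illustrative, not part of the patch):

import numpy as np

def npu_momentum_reference(param, grad, velocity, lr, mu,
                           use_nesterov=False, regularization_coeff=0.0,
                           l2_decay=False):
    # Matches the Muls (coeff * param) + Add (+ grad) pair in the kernel above.
    if l2_decay:
        grad = grad + regularization_coeff * param
    velocity_out = mu * velocity + grad
    if use_nesterov:
        param_out = param - (grad + mu * velocity_out) * lr
    else:
        param_out = param - lr * velocity_out
    return param_out, velocity_out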
@@ -2217,7 +2217,14 @@ All parameter, weight, gradient are variables in Paddle.
 #ifdef PADDLE_WITH_ASCEND_CL
   m.def("get_npu_device_count", platform::GetNPUDeviceCount);
-  m.def("npu_finalize", []() { platform::AclInstance::Instance().Finalize(); });
+  m.def("npu_finalize", []() {
+    auto &pool = platform::DeviceContextPool::Instance();
+    auto devices = platform::GetSelectedNPUDevices();
+    for (size_t i = 0; i < devices.size(); ++i) {
+      pool.Get(platform::NPUPlace(devices[i]))->Wait();
+    }
+    platform::AclInstance::Instance().Finalize();
+  });
   py::class_<platform::NPUProfConfigWrapper>(m, "NPUProfConfigWrapper");
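The replacement npu_finalize above is the "fix hang" item from the commit message: before AclInstance::Finalize() it waits on the device context of every selected NPU so that in-flight kernels finish first. A hedged usage sketch, assuming an Ascend (PADDLE_WITH_ASCEND_CL) build where this binding is exposed as paddle.fluid.core.npu_finalize; the training step is a placeholder:

import paddle
import paddle.fluid.core as core

paddle.set_device("npu:0")  # assumes at least one Ascend NPU is visible
# ... build and run a model here ...
core.npu_finalize()  # waits on every selected NPU, then finalizes ACL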
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from test_momentum_op import calculate_momentum_by_numpy
paddle.enable_static()
class TestMomentumOp1(OpTest):
def set_npu(self):
self.__class__.use_npu = True
def setUp(self):
self.set_npu()
self.op_type = "momentum"
self.init_dtype()
self.init_case()
param = np.random.random(self.shape).astype(self.dtype)
grad = np.random.random(self.shape).astype(self.dtype)
velocity = np.zeros(self.shape).astype(self.dtype)
learning_rate = np.array([0.001]).astype(np.float32)
mu = 0.0001
self.inputs = {
'Param': param,
'Grad': grad,
'Velocity': velocity,
'LearningRate': learning_rate
}
self.attrs = {'mu': mu, 'use_nesterov': self.use_nesterov}
param_out, velocity_out = calculate_momentum_by_numpy(
param=param,
grad=grad,
mu=mu,
velocity=velocity,
use_nesterov=self.use_nesterov,
learning_rate=learning_rate)
self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out}
def init_case(self):
self.shape = (123, 321)
self.use_nesterov = False
def init_dtype(self):
self.dtype = np.float32
def test_check_output(self):
self.check_output_with_place(core.NPUPlace(0))
class TestMomentumOpFp16(TestMomentumOp1):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
        self.check_output_with_place(core.NPUPlace(0), atol=1e-3)
class TestMomentumOp2(TestMomentumOp1):
def init_case(self):
self.shape = (123, 321)
self.use_nesterov = True
class TestMomentumV2(unittest.TestCase):
def test_momentum_dygraph(self):
paddle.disable_static(place=fluid.NPUPlace(0))
value = np.arange(26).reshape(2, 13).astype("float32")
a = paddle.to_tensor(value)
linear = paddle.nn.Linear(13, 5)
# This can be any optimizer supported by dygraph.
adam = paddle.optimizer.Momentum(
learning_rate=0.01, momentum=0.9, parameters=linear.parameters())
out = linear(a)
out.backward()
adam.step()
adam.clear_gradients()
def test_momentum(self):
paddle.enable_static()
place = fluid.NPUPlace(0)
main = fluid.Program()
with fluid.program_guard(main):
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)
rms_optimizer = paddle.optimizer.Momentum(
learning_rate=0.1, momentum=0.9)
rms_optimizer.minimize(avg_cost)
fetch_list = [avg_cost]
train_reader = paddle.batch(
paddle.dataset.uci_housing.train(), batch_size=1)
feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
for data in train_reader():
exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list)
def test_raise_error(self):
self.assertRaises(
ValueError, paddle.optimizer.Momentum, learning_rate=None)
self.assertRaises(ValueError, paddle.optimizer.Momentum, momentum=None)
class TestMomentumOpWithDecay(OpTest):
def set_npu(self):
self.__class__.use_npu = True
def setUp(self):
self.set_npu()
self.op_type = "momentum"
self.dtype = np.float32
self.use_nesterov = True
self.regularization_method = 'l2_decay'
self.regularization_coeff = 0.9
self.init_config()
param = np.random.random((123, 321)).astype(self.dtype)
grad = np.random.random((123, 321)).astype(self.dtype)
velocity = np.zeros((123, 321)).astype(self.dtype)
learning_rate = np.array([0.001]).astype(np.float32)
mu = 0.0001
use_nesterov = self.use_nesterov
regularization_method = self.regularization_method
regularization_coeff = self.regularization_coeff
self.inputs = {
'Param': param,
'Grad': grad,
'Velocity': velocity,
'LearningRate': learning_rate
}
self.attrs = {
'mu': mu,
'use_nesterov': use_nesterov,
'regularization_method': regularization_method,
'regularization_coeff': regularization_coeff
}
grad = grad + regularization_coeff * param
param_out, velocity_out = calculate_momentum_by_numpy(
param=param,
grad=grad,
mu=mu,
velocity=velocity,
use_nesterov=use_nesterov,
learning_rate=learning_rate)
self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out}
def init_config(self):
pass
def test_check_output(self):
paddle.enable_static()
self.check_output_with_place(core.NPUPlace(0), atol=3e-3)
class TestMomentumOpWithDecayFP16(TestMomentumOpWithDecay):
def init_config(self):
self.dtype = np.float16
def test_check_output(self):
paddle.enable_static()
        self.check_output_with_place(core.NPUPlace(0), atol=1e-3)
class TestMomentumOpWithDecay2(TestMomentumOpWithDecay):
def init_config(self):
self.use_nesterov = False
class TestMomentumOpWithDecayAPI(unittest.TestCase):
def _test_momentum_dygraph_common(self, regularization):
paddle.disable_static(fluid.NPUPlace(0))
inp = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
inp = paddle.to_tensor(inp)
out = linear(inp)
loss = paddle.mean(out)
# This can be any optimizer supported by dygraph.
momentum = paddle.fluid.contrib.optimizer.Momentum(
learning_rate=0.01,
momentum=0.9,
parameter_list=linear.parameters(),
regularization=regularization)
momentum.minimize(loss)
def test_momentum_dygraph_1(self):
self._test_momentum_dygraph_common(
regularization=paddle.fluid.regularizer.L2Decay(
regularization_coeff=0.1))
def test_momentum_static(self):
paddle.enable_static()
place = fluid.NPUPlace(0)
main = fluid.Program()
with fluid.program_guard(main):
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)
momentum_optimizer = paddle.fluid.contrib.optimizer.Momentum(
learning_rate=0.1, momentum=0.9)
momentum_optimizer.minimize(avg_cost)
fetch_list = [avg_cost]
train_reader = paddle.batch(
paddle.dataset.uci_housing.train(), batch_size=1)
feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
for data in train_reader():
exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list)
class TestMomentumOpVsMomentumOpWithDecayAPI(unittest.TestCase):
def __update_params(self, momentum, linear):
for i in range(10):
inp = paddle.full(
shape=[2, 2], fill_value=i, dtype='float32').astype("float32")
inp = paddle.to_tensor(inp)
out = linear(inp)
loss = paddle.mean(out)
loss.backward()
momentum.minimize(loss)
linear.clear_gradients()
def __test_vs(self, place=fluid.NPUPlace(0)):
paddle.disable_static(place=place)
linear_old = paddle.nn.Linear(
2,
2,
weight_attr=paddle.nn.initializer.Constant(value=2.0),
bias_attr=paddle.nn.initializer.Constant(value=2.0))
momentum_old = paddle.fluid.optimizer.Momentum(
learning_rate=0.01,
momentum=0.9,
parameter_list=linear_old.parameters(),
regularization=paddle.fluid.regularizer.L2Decay(
regularization_coeff=0.1))
self.__update_params(momentum=momentum_old, linear=linear_old)
linear_new = paddle.nn.Linear(
2,
2,
weight_attr=paddle.nn.initializer.Constant(value=2.0),
bias_attr=paddle.nn.initializer.Constant(value=2.0))
momentum_new = paddle.fluid.contrib.optimizer.Momentum(
learning_rate=0.01,
momentum=0.9,
parameter_list=linear_new.parameters(),
regularization=paddle.fluid.regularizer.L2Decay(
regularization_coeff=0.1))
self.__update_params(momentum=momentum_new, linear=linear_new)
self.assertEqual(
(linear_old.weight.numpy() == linear_new.weight.numpy()).all(),
True,
            'the parameter updated by the two Momentum optimizers should be equal')
def test_vs(self, place=fluid.NPUPlace(0)):
self.__test_vs(place=place)
class TestMomentumV2Group(TestMomentumV2):
def test_momentum_dygraph(self):
paddle.disable_static(place=fluid.NPUPlace(0))
value = np.arange(26).reshape(2, 13).astype("float32")
a = paddle.to_tensor(value)
linear_1 = paddle.nn.Linear(13, 5)
linear_2 = paddle.nn.Linear(5, 3)
# This can be any optimizer supported by dygraph.
adam = paddle.optimizer.Momentum(
learning_rate=0.01,
parameters=[{
'params': linear_1.parameters()
}, {
'params': linear_2.parameters(),
'weight_decay': 0.001,
'learning_rate': 0.1,
'momentum': 0.99
}],
weight_decay=0.1,
momentum=0.9)
out = linear_1(a)
out = linear_2(out)
out.backward()
adam.step()
adam.clear_gradients()
if __name__ == "__main__":
unittest.main()