未验证 提交 533c649f 编写于 作者: H houj04 提交者: GitHub

momentum support l2decay for xpu. test=kunlun (#41325)

* momentum support l2decay for xpu. test=kunlun

* fix include file. test=kunlun

* fix cmake for device_worker. test=kunlun
上级 56e72b20
...@@ -36,7 +36,7 @@ ENDIF() ...@@ -36,7 +36,7 @@ ENDIF()
if(NOT DEFINED XPU_BASE_URL) if(NOT DEFINED XPU_BASE_URL)
SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev") SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220331") SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220402")
else() else()
SET(XPU_BASE_URL "${XPU_BASE_URL}") SET(XPU_BASE_URL "${XPU_BASE_URL}")
endif() endif()
......
...@@ -117,12 +117,14 @@ endif() ...@@ -117,12 +117,14 @@ endif()
cc_test(var_type_traits_test SRCS var_type_traits_test.cc DEPS var_type_traits) cc_test(var_type_traits_test SRCS var_type_traits_test.cc DEPS var_type_traits)
set(BRPC_DEPS "") set(BRPC_DEPS "")
if(WITH_PSLIB OR WITH_PSCORE) if(WITH_PSCORE)
if(NOT WITH_HETERPS) set(BRPC_DEPS brpc ssl crypto)
set(BRPC_DEPS brpc ssl crypto) endif()
endif() if(WITH_PSLIB)
if(WITH_PSLIB_BRPC) if(WITH_PSLIB_BRPC)
set(BRPC_DEPS pslib_brpc) set(BRPC_DEPS pslib_brpc)
elseif(NOT WITH_HETERPS)
set(BRPC_DEPS brpc ssl crypto)
endif() endif()
endif() endif()
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
#include <string> #include <string>
#include "paddle/fluid/operators/optimizers/sgd_op.h" #include "paddle/fluid/operators/optimizers/sgd_op.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -33,6 +34,13 @@ class MomentumOpXPUKernel : public framework::OpKernel<T> { ...@@ -33,6 +34,13 @@ class MomentumOpXPUKernel : public framework::OpKernel<T> {
velocity_out->mutable_data<T>(ctx.GetPlace()); velocity_out->mutable_data<T>(ctx.GetPlace());
auto* lr = learning_rate->data<T>(); auto* lr = learning_rate->data<T>();
auto regularization_method = ctx.Attr<std::string>("regularization_method");
auto regularization_coeff = ctx.Attr<float>("regularization_coeff");
if (regularization_method != "l2_decay") {
// only support l2_decay
regularization_coeff = 0.0f;
}
auto* grad_var = ctx.InputVar("Grad"); auto* grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE_EQ(grad_var->IsType<framework::LoDTensor>(), true, PADDLE_ENFORCE_EQ(grad_var->IsType<framework::LoDTensor>(), true,
platform::errors::PermissionDenied( platform::errors::PermissionDenied(
...@@ -44,28 +52,16 @@ class MomentumOpXPUKernel : public framework::OpKernel<T> { ...@@ -44,28 +52,16 @@ class MomentumOpXPUKernel : public framework::OpKernel<T> {
auto grad = ctx.Input<framework::Tensor>("Grad"); auto grad = ctx.Input<framework::Tensor>("Grad");
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
// int momentum(Context* ctx, const T* param, const T* velocity, const T*
// grad, T* param_out, T* velocity_out, int len, const float* lr, int
// use_nesterov, float mu, float l2_weight_decay);
int r = xpu::momentum(dev_ctx.x_context(), param->data<float>(), int r = xpu::momentum(dev_ctx.x_context(), param->data<float>(),
velocity->data<float>(), grad->data<float>(), velocity->data<float>(), grad->data<float>(),
param_out->data<float>(), velocity_out->data<float>(), param_out->data<float>(), velocity_out->data<float>(),
param_out->numel(), lr, use_nesterov, mu); param_out->numel(), lr, use_nesterov, mu,
if (r == xpu::Error_t::INVALID_PARAM) { regularization_coeff);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_XDNN_SUCCESS(r, "momentum");
r, xpu::Error_t::SUCCESS,
platform::errors::InvalidArgument(
"XPU kernel error of MomentumOp, error message: INVALID_PARAM, "
"please check your input & output."));
} else if (r == xpu::Error_t::RUNTIME_ERROR) {
PADDLE_ENFORCE_EQ(
r, xpu::Error_t::SUCCESS,
platform::errors::Unavailable(
"XPU kernel error of MomentumOp, error message: RUNTIME_ERROR, "
"please check whether Baidu Kunlun card is properly installed."));
} else if (r == xpu::Error_t::NO_ENOUGH_WORKSPACE) {
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
platform::errors::ResourceExhausted(
"XPU kernel error of MomentumOp, error message: "
"NO_ENOUGH_WORKSPACE, XPU has no enough memory."));
}
} }
}; };
} // namespace operators } // namespace operators
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -17,52 +17,150 @@ from __future__ import print_function ...@@ -17,52 +17,150 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
import sys import sys
import os
sys.path.append("..") sys.path.append("..")
from op_test import OpTest
import paddle
from paddle.fluid import core
from paddle.fluid.op import Operator
class TestMomentumOp1(OpTest): import paddle
def setUp(self): import paddle.fluid.core as core
self.op_type = "momentum"
self.dtype = np.float32
self.init_dtype()
param = np.random.random((123, 321)).astype(self.dtype) from op_test import OpTest
grad = np.random.random((123, 321)).astype(self.dtype) from op_test_xpu import XPUOpTest
velocity = np.zeros((123, 321)).astype(self.dtype) from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
learning_rate = np.array([0.001]).astype(self.dtype)
mu = 0.0001
use_nesterov = False
self.inputs = { paddle.enable_static()
'Param': param,
'Grad': grad,
'Velocity': velocity,
'LearningRate': learning_rate
}
self.attrs = {'mu': mu}
def calculate_momentum_by_numpy(param, grad, mu, velocity, use_nesterov,
learning_rate, regularization_method,
regularization_coeff):
if regularization_method == "l2_decay":
grad = grad + regularization_coeff * param
velocity_out = mu * velocity + grad
if use_nesterov:
param_out = param - (grad + velocity_out * mu) * learning_rate
else:
param_out = param - learning_rate * velocity_out
else:
velocity_out = mu * velocity + grad velocity_out = mu * velocity + grad
if use_nesterov: if use_nesterov:
param_out = param - grad * learning_rate - \ param_out = param - grad * learning_rate - \
velocity_out * mu * learning_rate velocity_out * mu * learning_rate
else: else:
param_out = param - learning_rate * velocity_out param_out = param - learning_rate * velocity_out
return param_out, velocity_out
class XPUTestMomentumOP(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'momentum'
self.use_dynamic_create_class = False
class TestMomentumOPBase(XPUOpTest):
def setUp(self):
self.place = paddle.XPUPlace(0)
self.xpu_version = core.get_xpu_device_version(0)
self.init_dtype()
self.set_case()
def set_case(self):
self.op_type = 'momentum'
self.dtype = self.in_type
self.init_config()
self.param = np.random.uniform(-1, 1,
self.input_shape).astype(self.dtype)
self.grad = np.random.uniform(-1, 1,
self.input_shape).astype(self.dtype)
self.velocity = np.random.uniform(
-1, 1, self.input_shape).astype(self.dtype)
param_out, velocity_out = calculate_momentum_by_numpy(
param=self.param,
grad=self.grad,
mu=self.mu,
velocity=self.velocity,
use_nesterov=self.use_nesterov,
learning_rate=self.learning_rate,
regularization_method=self.regularization_method,
regularization_coeff=self.regularization_coeff)
self.inputs = {
'Param': self.param,
'Grad': self.grad,
'Velocity': self.velocity,
'LearningRate': self.learning_rate,
}
self.attrs = {
'use_xpu': True,
'mu': self.mu,
'use_nesterov': self.use_nesterov,
'regularization_method': self.regularization_method,
'regularization_coeff': self.regularization_coeff
}
self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out}
def init_dtype(self):
self.dtype = np.float32
def test_check_output(self):
self.check_output_with_place(self.place)
def init_config(self):
self.input_shape = [864]
self.learning_rate = np.array([0.001]).astype(self.dtype)
self.mu = 0.0001
self.use_nesterov = False
self.regularization_method = None
self.regularization_coeff = 0
class XPUTestMomentum1(TestMomentumOPBase):
def init_config(self):
self.input_shape = [2, 768]
self.learning_rate = np.array([0.002]).astype(self.dtype)
self.mu = 0.001
self.use_nesterov = False
self.regularization_method = None
self.regularization_coeff = 0
class XPUTestMomentum2(TestMomentumOPBase):
def init_config(self):
self.input_shape = [3, 8, 4096]
self.learning_rate = np.array([0.005]).astype(self.dtype)
self.mu = 0.002
self.use_nesterov = True
self.regularization_method = None
self.regularization_coeff = 0
self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out} class XPUTestMomentum3(TestMomentumOPBase):
def init_config(self):
self.input_shape = [1024]
self.learning_rate = np.array([0.01]).astype(self.dtype)
self.mu = 0.0001
self.use_nesterov = False
if self.xpu_version != core.XPUVersion.XPU1:
self.regularization_method = "l2_decay"
self.regularization_coeff = 0.005
else:
# regularization not supported on XPU1
self.regularization_method = None
self.regularization_coeff = 0
def init_dtype(self): class XPUTestMomentum4(TestMomentumOPBase):
pass def init_config(self):
self.input_shape = [2, 2, 255]
self.learning_rate = np.array([0.0005]).astype(self.dtype)
self.mu = 0.005
self.use_nesterov = True
if self.xpu_version != core.XPUVersion.XPU1:
self.regularization_method = "l2_decay"
self.regularization_coeff = 0.005
else:
# regularization not supported on XPU1
self.regularization_method = None
self.regularization_coeff = 0
def test_check_output_with_place(self):
self.check_output_with_place(paddle.XPUPlace(0))
support_types = get_xpu_op_support_types('momentum')
for stype in support_types:
create_test_class(globals(), XPUTestMomentumOP, stype)
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static()
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册