未验证 提交 1a0ef45e 编写于 作者: T taixiurong 提交者: GitHub

xpu-paddlepaddle-37 [任务] 迁移lamb到phi (#45520)

test=kunlun
上级 5696f967
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "gflags/gflags.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
#ifdef PADDLE_WITH_XPU
template <typename DeviceContext, typename T>
class LambOpXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using paddle::framework::LoDTensor;
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(),
true,
platform::errors::InvalidArgument(
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.InputNames("Param").front(),
framework::ToTypeName(param_var->Type())));
using paddle::framework::LoDTensor;
// inputs
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
T weight_decay = static_cast<T>(ctx.Attr<float>("weight_decay"));
T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
auto& param = GET_DATA_SAFELY(
ctx.Input<LoDTensor>("Param"), "Input", "Param", "Lamb");
auto* grad_var = ctx.InputVar("Grad");
auto& mom1 = GET_DATA_SAFELY(
ctx.Input<LoDTensor>("Moment1"), "Input", "Moment1", "Lamb");
auto& mom2 = GET_DATA_SAFELY(
ctx.Input<LoDTensor>("Moment2"), "Input", "Moment2", "Lamb");
auto& lr = GET_DATA_SAFELY(
ctx.Input<LoDTensor>("LearningRate"), "Input", "LearningRate", "Lamb");
auto& beta1_pow = GET_DATA_SAFELY(
ctx.Input<LoDTensor>("Beta1Pow"), "Input", "Beta1Pow", "Lamb");
auto& beta2_pow = GET_DATA_SAFELY(
ctx.Input<LoDTensor>("Beta2Pow"), "Input", "Beta2Pow", "Lamb");
auto& param_out = GET_DATA_SAFELY(
ctx.Output<LoDTensor>("ParamOut"), "Output", "ParamOut", "Lamb");
auto& mom1_out = GET_DATA_SAFELY(
ctx.Output<LoDTensor>("Moment1Out"), "Output", "Moment1Out", "Lamb");
auto& mom2_out = GET_DATA_SAFELY(
ctx.Output<LoDTensor>("Moment2Out"), "Output", "Moment2Out", "Lamb");
auto& beta1_pow_out = GET_DATA_SAFELY(
ctx.Output<LoDTensor>("Beta1PowOut"), "Output", "Beta1PowOut", "Lamb");
auto& beta2_pow_out = GET_DATA_SAFELY(
ctx.Output<LoDTensor>("Beta2PowOut"), "Output", "Beta2PowOut", "Lamb");
auto& dev_ctx = ctx.template device_context<DeviceContext>();
if (grad_var->IsType<framework::LoDTensor>()) {
auto& grad = *ctx.Input<LoDTensor>("Grad");
int r = xpu::lamb(dev_ctx.x_context(),
grad.template data<T>(),
mom1.template data<T>(),
mom2.template data<T>(),
param.template data<T>(),
beta1_pow.template data<T>(),
beta2_pow.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2_out.template mutable_data<T>(ctx.GetPlace()),
param_out.template mutable_data<T>(ctx.GetPlace()),
beta1_pow_out.template mutable_data<T>(ctx.GetPlace()),
beta2_pow_out.template mutable_data<T>(ctx.GetPlace()),
beta1,
beta2,
epsilon,
weight_decay,
lr.template data<T>(),
param.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "lamb");
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Variable type not supported by lamb_op. Expect LoDTensor, "
"but got %s",
framework::ToTypeName(param_var->Type())));
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
lamb, ops::LambOpXPUKernel<paddle::platform::XPUDeviceContext, float>);
#endif
...@@ -306,6 +306,9 @@ XPUOpMap& get_kl2_ops() { ...@@ -306,6 +306,9 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"label_smooth", {"label_smooth",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"lamb",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"lars_momentum", {"lars_momentum",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})}, pOpKernelType(vartype::FP16, XPUPlace())})},
...@@ -650,8 +653,7 @@ XPUOpMap& get_kl2_ops() { ...@@ -650,8 +653,7 @@ XPUOpMap& get_kl2_ops() {
{"resnet_basic_block_grad", {"resnet_basic_block_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"resnet_basic_block", {"resnet_basic_block",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}};
};
return s_xpu2_kernels; return s_xpu2_kernels;
} }
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/lamb_kernel.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void LambKernel(const Context& dev_ctx,
const DenseTensor& param,
const DenseTensor& grad,
const DenseTensor& learning_rate,
const DenseTensor& moment1,
const DenseTensor& moment2,
const DenseTensor& beta1_pow,
const DenseTensor& beta2_pow,
const paddle::optional<DenseTensor>& master_param,
const paddle::optional<DenseTensor>& skip_update,
float weight_decay,
float beta1,
float beta2,
float epsilon,
bool multi_precision,
DenseTensor* param_outs,
DenseTensor* moment1_out,
DenseTensor* moment2_out,
DenseTensor* beta1_pow_out,
DenseTensor* beta2_pow_out,
DenseTensor* master_param_outs) {
using XPUType = typename XPUTypeTrait<T>::Type;
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
if (!multi_precision) {
constexpr auto kIsSameType = std::is_same<T, MT>::value;
PADDLE_ENFORCE_EQ(
kIsSameType,
true,
phi::errors::InvalidArgument(
"When multi_precision=False, T and MT must be the same type."));
}
bool cpu_skip_update = false;
if (skip_update && skip_update->IsInitialized()) {
if (paddle::platform::is_cpu_place(skip_update->place())) {
cpu_skip_update = *(skip_update->data<bool>());
} else {
const bool* skip_update_flag = skip_update->data<bool>();
paddle::memory::Copy(phi::CPUPlace(),
static_cast<void*>(&cpu_skip_update),
dev_ctx.GetPlace(),
static_cast<const void*>(skip_update_flag),
sizeof(bool));
}
}
if (cpu_skip_update) {
return;
}
// tensor --> data_ptr
// inputs
const XPUType* param_ptr = reinterpret_cast<const XPUType*>(param.data<T>());
const XPUType* grad_ptr = reinterpret_cast<const XPUType*>(grad.data<T>());
const MT* learning_rate_ptr = learning_rate.data<MT>();
const MT* moment1_ptr = moment1.data<MT>();
const MT* moment2_ptr = moment2.data<MT>();
const MT* beta1_pow_ptr = beta1_pow.data<MT>();
const MT* beta2_pow_ptr = beta2_pow.data<MT>();
const MT* master_param_ptr = nullptr;
if (multi_precision) {
master_param_ptr = master_param.get_ptr()->data<MT>();
}
// outputs
XPUType* param_outs_ptr =
reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(param_outs));
MT* moment1_out_ptr = dev_ctx.template Alloc<MT>(moment1_out);
MT* moment2_out_ptr = dev_ctx.template Alloc<MT>(moment2_out);
MT* master_param_outs_ptr = nullptr;
if (multi_precision) {
if (master_param_outs->numel() != master_param.get_ptr()->numel()) {
master_param_outs->Resize(master_param.get_ptr()->dims());
}
master_param_outs_ptr = dev_ctx.template Alloc<MT>(master_param_outs);
}
MT* beta1_pow_out_ptr = nullptr;
MT* beta2_pow_out_ptr = nullptr;
MT* beta1_pow_xpu_ptr = nullptr;
MT* beta2_pow_xpu_ptr = nullptr;
xpu::Context* xpu_ctx = dev_ctx.x_context();
xpu::ctx_guard RAII_GUARD(xpu_ctx);
if (beta1_pow.place().GetType() == phi::AllocationType::CPU) {
int r = xpu_malloc(reinterpret_cast<void**>(&beta1_pow_xpu_ptr),
(beta1_pow.numel()) * sizeof(MT));
PADDLE_ENFORCE_XPU_SUCCESS(r);
paddle::memory::Copy(dev_ctx.GetPlace(),
beta1_pow_xpu_ptr,
beta1_pow.place(),
beta1_pow.data<MT>(),
sizeof(MT) * beta1_pow.numel());
beta1_pow_ptr = beta1_pow_xpu_ptr;
beta1_pow_out_ptr = RAII_GUARD.alloc_l3_or_gm<MT>(beta1_pow_out->numel());
PADDLE_ENFORCE_XDNN_NOT_NULL(beta1_pow_out_ptr);
} else {
beta1_pow_out_ptr = dev_ctx.template Alloc<MT>(beta1_pow_out);
}
if (beta2_pow.place().GetType() == phi::AllocationType::CPU) {
int r = xpu_malloc(reinterpret_cast<void**>(&beta2_pow_xpu_ptr),
(beta2_pow.numel()) * sizeof(MT));
PADDLE_ENFORCE_XPU_SUCCESS(r);
paddle::memory::Copy(dev_ctx.GetPlace(),
beta2_pow_xpu_ptr,
beta2_pow.place(),
beta2_pow.data<MT>(),
sizeof(MT) * beta2_pow.numel());
beta2_pow_ptr = beta2_pow_xpu_ptr;
beta2_pow_out_ptr = RAII_GUARD.alloc_l3_or_gm<MT>(beta2_pow_out->numel());
PADDLE_ENFORCE_XDNN_NOT_NULL(beta2_pow_out_ptr);
} else {
beta2_pow_out_ptr = dev_ctx.template Alloc<MT>(beta2_pow_out);
}
const MT* param_calc_ptr = nullptr;
const MT* grad_calc_ptr = nullptr;
MT* param_outs_calc_ptr = nullptr;
if (std::is_same<T, phi::dtype::float16>::value) {
MT* param_float = RAII_GUARD.alloc_l3_or_gm<MT>(param.numel());
PADDLE_ENFORCE_XDNN_NOT_NULL(param_float);
MT* grad_float = RAII_GUARD.alloc_l3_or_gm<MT>(grad.numel());
PADDLE_ENFORCE_XDNN_NOT_NULL(grad_float);
MT* param_outs_float = RAII_GUARD.alloc_l3_or_gm<MT>(param_outs->numel());
PADDLE_ENFORCE_XDNN_NOT_NULL(param_outs_float);
int r =
xpu::cast<XPUType, MT>(xpu_ctx, param_ptr, param_float, param.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
r = xpu::cast<XPUType, MT>(xpu_ctx, grad_ptr, grad_float, grad.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
param_calc_ptr = param_float;
grad_calc_ptr = grad_float;
param_outs_calc_ptr = param_outs_float;
} else {
param_calc_ptr = reinterpret_cast<const MT*>(param_ptr);
grad_calc_ptr = reinterpret_cast<const MT*>(grad_ptr);
param_outs_calc_ptr = reinterpret_cast<MT*>(param_outs_ptr);
}
int r = xpu::lamb<MT>(
xpu_ctx,
grad_calc_ptr,
moment1_ptr,
moment2_ptr,
(multi_precision ? master_param_ptr : param_calc_ptr),
beta1_pow_ptr,
beta2_pow_ptr,
moment1_out_ptr,
moment2_out_ptr,
(multi_precision ? master_param_outs_ptr : param_outs_calc_ptr),
beta1_pow_out_ptr,
beta2_pow_out_ptr,
beta1,
beta2,
epsilon,
weight_decay,
learning_rate_ptr,
param.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "lamb");
if (std::is_same<T, phi::dtype::float16>::value && multi_precision == false) {
int r = xpu::cast<MT, XPUType>(
xpu_ctx, param_outs_calc_ptr, param_outs_ptr, param_outs->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
}
if (beta1_pow.place().GetType() == phi::AllocationType::CPU) {
// copy beta1_pow_out from xpu to cpu
paddle::memory::Copy(beta1_pow.place(),
dev_ctx.template HostAlloc<MT>(beta1_pow_out),
dev_ctx.GetPlace(),
beta1_pow_out_ptr,
sizeof(MT) * beta1_pow_out->numel());
if (beta1_pow_xpu_ptr) {
xpu_free(beta1_pow_xpu_ptr);
}
}
if (beta2_pow.place().GetType() == phi::AllocationType::CPU) {
// copy beta2_pow_out from xpu to cpu
paddle::memory::Copy(beta2_pow.place(),
dev_ctx.template HostAlloc<MT>(beta2_pow_out),
dev_ctx.GetPlace(),
beta2_pow_out_ptr,
sizeof(MT) * beta2_pow_out->numel());
if (beta2_pow_xpu_ptr) {
xpu_free(beta2_pow_xpu_ptr);
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(
lamb, XPU, ALL_LAYOUT, phi::LambKernel, float, phi::dtype::float16) {
kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
}
...@@ -88,6 +88,7 @@ xpu_test_op_type_white_list = [ ...@@ -88,6 +88,7 @@ xpu_test_op_type_white_list = [
'dropout_float16', 'dropout_float16',
'dropout_grad_float16', 'dropout_grad_float16',
"grad_add_float32", # no api for grad_add, skip "grad_add_float32", # no api for grad_add, skip
"lamb_float16",
"lars_momentum_float32", "lars_momentum_float32",
"resnet_unit", "resnet_unit",
"resnet_unit_grad" "resnet_unit_grad"
......
...@@ -23,27 +23,82 @@ from paddle.fluid import core ...@@ -23,27 +23,82 @@ from paddle.fluid import core
from paddle.fluid.op import Operator from paddle.fluid.op import Operator
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle import paddle
"""
class TestLambOp1(XPUOpTest): from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
def lamb_step(inputs, attributes):
'''
Simulate one step of the lamb optimizer
:param inputs: dict of inputs
:param attributes: dict of attributes
:return tuple: tuple of output param, moment1, moment2,
beta1 power accumulator and beta2 power accumulator
'''
param = inputs['Param']
grad = inputs['Grad']
moment1 = inputs['Moment1']
moment2 = inputs['Moment2']
lr = inputs['LearningRate']
beta1_pow = inputs['Beta1Pow']
beta2_pow = inputs['Beta2Pow']
beta1 = attributes['beta1']
beta2 = attributes['beta2']
epsilon = attributes['epsilon']
weight_decay = attributes['weight_decay']
moment1_out = beta1 * moment1 + (1 - beta1) * grad
moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad)
moment1_unbiased = moment1_out / (1 - beta1_pow)
moment2_unbiased = moment2_out / (1 - beta2_pow)
r_1 = np.linalg.norm(param)
r_2 = np.linalg.norm(moment1_unbiased /
(np.sqrt(moment2_unbiased) + epsilon) +
weight_decay * param)
lr_t = lr * r_1 / r_2
param_out = param - lr_t * (moment1_unbiased /
(np.sqrt(moment2_unbiased) + epsilon) +
weight_decay * param)
beta1_pow_out = beta1_pow * beta1
beta2_pow_out = beta2_pow * beta2
return param_out, moment1_out, moment2_out, beta1_pow_out, beta2_pow_out
class XPUTestLambOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'lamb'
self.use_dynamic_create_class = False
class TestLambOp1(XPUOpTest):
def set_attrs(self): def set_attrs(self):
self.attrs = { self.attrs = {
'epsilon': 1e-6, 'epsilon': 1e-4,
'beta1': 0.9, 'beta1': 0.78,
'beta2': 0.999, 'beta2': 0.836,
'weight_decay': 0.01 'weight_decay': 0.01
} }
def setUp(self): def setUp(self):
'''Test Lamb Op with supplied attributes '''Test Lamb Op with supplied attributes
''' '''
self.op_type = 'lamb' # self.op_type = self.op_name
param = np.random.uniform(-1, 1, 5000).astype('float32') self.__class__.op_type = 'lamb'
grad = np.random.uniform(-1, 1, 5000).astype('float32') self.dtype = self.in_type
moment1 = np.random.uniform(-1, 1, 5000).astype('float32') param = np.random.uniform(-1, 1, (102, 105)).astype(self.dtype)
moment2 = np.random.random(5000).astype('float32') grad = np.random.uniform(-1, 1, (102, 105)).astype(self.dtype)
moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32")
moment2 = np.random.random((102, 105)).astype("float32")
self.set_attrs()
learning_rate = 0.001 learning_rate = 0.001
self.set_attrs()
beta1_pow = self.attrs['beta1'] beta1_pow = self.attrs['beta1']
beta2_pow = self.attrs['beta2'] beta2_pow = self.attrs['beta2']
...@@ -52,11 +107,12 @@ class TestLambOp1(XPUOpTest): ...@@ -52,11 +107,12 @@ class TestLambOp1(XPUOpTest):
'Grad': grad, 'Grad': grad,
'Moment1': moment1, 'Moment1': moment1,
'Moment2': moment2, 'Moment2': moment2,
'LearningRate': np.array([learning_rate]).astype('float32'), 'LearningRate': np.array([learning_rate]).astype("float32"),
'Beta1Pow': np.array([beta1_pow]).astype('float32'), 'Beta1Pow': np.array([beta1_pow]).astype("float32"),
'Beta2Pow': np.array([beta2_pow]).astype('float32') 'Beta2Pow': np.array([beta2_pow]).astype("float32")
} }
param_out, moment1_out, moment2_out, \ param_out, moment1_out, moment2_out, \
beta1_pow_out, beta2_pow_out = lamb_step(self.inputs, self.attrs) beta1_pow_out, beta2_pow_out = lamb_step(self.inputs, self.attrs)
...@@ -71,50 +127,60 @@ class TestLambOp1(XPUOpTest): ...@@ -71,50 +127,60 @@ class TestLambOp1(XPUOpTest):
def test_check_output(self): def test_check_output(self):
self.check_output_with_place(paddle.XPUPlace(0)) self.check_output_with_place(paddle.XPUPlace(0))
class TestLambOp2(TestLambOp1):
def lamb_step(inputs, attributes): def set_attrs(self):
''' self.attrs = {
Simulate one step of the lamb optimizer 'epsilon': 1e-8,
:param inputs: dict of inputs 'beta1': 0.9,
:param attributes: dict of attributes 'beta2': 0.999,
:return tuple: tuple of output param, moment1, moment2, 'weight_decay': 0.01
beta1 power accumulator and beta2 power accumulator }
'''
param = inputs['Param']
grad = inputs['Grad']
moment1 = inputs['Moment1']
moment2 = inputs['Moment2']
lr = inputs['LearningRate']
beta1_pow = inputs['Beta1Pow']
beta2_pow = inputs['Beta2Pow']
beta1 = attributes['beta1'] class TestLambOpMultipleSteps(TestLambOp1):
beta2 = attributes['beta2']
epsilon = attributes['epsilon']
weight_decay = attributes['weight_decay']
moment1_out = beta1 * moment1 + (1 - beta1) * grad def set_attrs(self):
moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad) self.attrs = {
'epsilon': 1e-8,
'beta1': 0.9,
'beta2': 0.999,
'weight_decay': 0.01
}
self.num_steps = 10
moment1_unbiased = moment1_out / (1 - beta1_pow) def test_check_output(self):
moment2_unbiased = moment2_out / (1 - beta2_pow) for i in range(self.num_steps):
param_out, moment1_out, moment2_out, \
beta1_pow_out, beta2_pow_out = lamb_step(self.inputs, self.attrs)
r_1 = np.linalg.norm(param) self.outputs = {
r_2 = np.linalg.norm(moment1_unbiased / (np.sqrt(moment2_unbiased) + epsilon 'Moment1Out': moment1_out,
) + weight_decay * param) 'Moment2Out': moment2_out,
if r_1 > 0.0 and r_2 > 0.0: 'ParamOut': param_out,
lr_t = lr * r_1 / r_2 'Beta1PowOut': beta1_pow_out,
else: 'Beta2PowOut': beta2_pow_out
lr_t = 1.0 }
param_out = param - lr_t * (moment1_unbiased / ( # Verify output for this step
np.sqrt(moment2_unbiased) + epsilon) + weight_decay * param) self.check_output()
beta1_pow_out = beta1_pow * beta1 # Output of this step becomes input for next step
beta2_pow_out = beta2_pow * beta2 self.inputs['Param'] = param_out
self.inputs['Moment1'] = moment1_out
self.inputs['Moment2'] = moment2_out
return param_out, moment1_out, moment2_out, beta1_pow_out, beta2_pow_out # Update powers of Beta1 and Beta2 for next time step
""" self.inputs['Beta1Pow'] = beta1_pow_out
self.inputs['Beta2Pow'] = beta2_pow_out
# Randomize gradient for next step
self.inputs['Grad'] = np.random.uniform(
-1, 1, (102, 105)).astype("float32")
support_types = get_xpu_op_support_types('lamb')
for stype in support_types:
create_test_class(globals(), XPUTestLambOp, stype)
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static() paddle.enable_static()
......
...@@ -108,6 +108,7 @@ class Lamb(Optimizer): ...@@ -108,6 +108,7 @@ class Lamb(Optimizer):
parameters=None, parameters=None,
grad_clip=None, grad_clip=None,
exclude_from_weight_decay_fn=None, exclude_from_weight_decay_fn=None,
multi_precision=False,
name=None): name=None):
assert learning_rate is not None assert learning_rate is not None
assert beta1 is not None assert beta1 is not None
...@@ -134,7 +135,7 @@ class Lamb(Optimizer): ...@@ -134,7 +135,7 @@ class Lamb(Optimizer):
self._master_weights = {} self._master_weights = {}
self._used_master_weights = {} self._used_master_weights = {}
# TODO(zengjinle): expose API as soon as possible # TODO(zengjinle): expose API as soon as possible
self._multi_precision = False self._multi_precision = multi_precision
def _get_parameter(self, name, scope=None): def _get_parameter(self, name, scope=None):
if scope is None: if scope is None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册