Unverified commit 700205e8 authored by zhaoyingli, committed by GitHub

fix cpu adamw problem for np.float64 (#35124)

Parent 91ba86b1
@@ -31,19 +31,22 @@ class AdamWFunctor;
 template <typename T>
 class AdamWFunctor<T, CPUAdamW> {
  private:
-  const float coeff_;
-  const float learning_rate_;
+  const T coeff_;
+  const T* lr_;
   T* param_;

  public:
-  AdamWFunctor(const float& coeff, const float& learning_rate, T* param)
-      : coeff_(coeff), learning_rate_(learning_rate), param_(param) {}
+  AdamWFunctor(const T coeff, const T* lr, T* param)
+      : coeff_(coeff), lr_(lr), param_(param) {}

   inline HOSTDEVICE void operator()(size_t numel) const {
     Eigen::Map<Eigen::Array<T, 1, Eigen::Dynamic>> param{
         param_, static_cast<Eigen::Index>(numel)};
+    T lr = *lr_;
+
     // Calculation
-    param = param * (1.0f - learning_rate_ * coeff_);
+    param = param * (1 - lr * coeff_);
   }
 };
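For reference, the templated functor applies AdamW's decoupled weight decay, param *= (1 - lr * coeff), entirely in the parameter's element type T, so float64 inputs stay in double precision. Below is a minimal standalone sketch of the same update with T = double (plain Eigen, outside Paddle's kernel machinery; the function and buffer names are illustrative only):

// Illustrative sketch, not Paddle code: the decoupled weight-decay step
// from the functor above, instantiated with T = double so np.float64
// parameters and learning rates stay in double precision throughout.
#include <Eigen/Core>
#include <cstdio>

template <typename T>
void ApplyDecoupledWeightDecay(T coeff, const T* lr, T* param_data,
                               size_t numel) {
  Eigen::Map<Eigen::Array<T, 1, Eigen::Dynamic>> param{
      param_data, static_cast<Eigen::Index>(numel)};
  T lr_value = *lr;
  // Same update as AdamWFunctor<T, CPUAdamW>::operator().
  param = param * (1 - lr_value * coeff);
}

int main() {
  double params[3] = {1.0, 2.0, 3.0};
  double lr = 0.001;
  ApplyDecoupledWeightDecay<double>(/*coeff=*/0.01, &lr, params, 3);
  // Each parameter is scaled by (1 - 0.001 * 0.01) = 0.99999.
  std::printf("%.6f %.6f %.6f\n", params[0], params[1], params[2]);
  return 0;
}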
@@ -183,7 +186,7 @@ class AdamWOpKernel : public AdamOpKernel<DeviceContext, T> {
       return;
     }
-    float coeff = ctx.Attr<float>("coeff");
+    T coeff = static_cast<T>(ctx.Attr<float>("coeff"));
     auto* lr = ctx.Input<LoDTensor>("LearningRate");

     LoDTensor* param;
@@ -195,9 +198,7 @@ class AdamWOpKernel : public AdamOpKernel<DeviceContext, T> {
       param = const_cast<LoDTensor*>(ctx.Input<LoDTensor>("Param"));
     }
-    // AdamWFunctor(float coeff, const float* learning_rate, T* parma)
-    AdamWFunctor<T, CPUAdamW> functor(coeff, *lr->data<float>(),
-                                      param->data<T>());
+    AdamWFunctor<T, CPUAdamW> functor(coeff, lr->data<T>(), param->data<T>());
     functor(param->numel());

     AdamOpKernel<DeviceContext, T>::Compute(ctx);
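The kernel-side change mirrors the new functor signature: the learning rate is passed as lr->data<T>() instead of being dereferenced through a hard-coded float, so a float64 LearningRate tensor is read with the kernel's own element type. A rough illustration of the underlying type mismatch follows (plain C++, not Paddle's tensor API; in Paddle the mismatch would typically surface through the tensor type check rather than a silent misread):

// Illustrative sketch, not Paddle code: why reading a double (np.float64)
// learning-rate buffer through a float pointer is wrong, and why the
// kernel now reads it with its template parameter T instead.
#include <cstdio>

template <typename T>
T ReadLearningRate(const void* buffer) {
  return *static_cast<const T*>(buffer);
}

int main() {
  double lr = 0.001;  // LearningRate tensor element, dtype float64
  float wrong = ReadLearningRate<float>(&lr);    // old pattern: hard-coded float
  double right = ReadLearningRate<double>(&lr);  // new pattern: element type T
  std::printf("as float: %g, as double: %g\n", wrong, right);
  return 0;
}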