未验证 提交 700205e8 编写于 作者: Z zhaoyingli 提交者: GitHub

fix cpu adamw problem for np.float64 (#35124)

上级 91ba86b1
...@@ -31,19 +31,22 @@ class AdamWFunctor; ...@@ -31,19 +31,22 @@ class AdamWFunctor;
template <typename T> template <typename T>
class AdamWFunctor<T, CPUAdamW> { class AdamWFunctor<T, CPUAdamW> {
private: private:
const float coeff_; const T coeff_;
const float learning_rate_; const T* lr_;
T* param_; T* param_;
public: public:
AdamWFunctor(const float& coeff, const float& learning_rate, T* param) AdamWFunctor(const T coeff, const T* lr, T* param)
: coeff_(coeff), learning_rate_(learning_rate), param_(param) {} : coeff_(coeff), lr_(lr), param_(param) {}
inline HOSTDEVICE void operator()(size_t numel) const { inline HOSTDEVICE void operator()(size_t numel) const {
Eigen::Map<Eigen::Array<T, 1, Eigen::Dynamic>> param{ Eigen::Map<Eigen::Array<T, 1, Eigen::Dynamic>> param{
param_, static_cast<Eigen::Index>(numel)}; param_, static_cast<Eigen::Index>(numel)};
T lr = *lr_;
// Calculation // Calculation
param = param * (1.0f - learning_rate_ * coeff_); param = param * (1 - lr * coeff_);
} }
}; };
...@@ -183,7 +186,7 @@ class AdamWOpKernel : public AdamOpKernel<DeviceContext, T> { ...@@ -183,7 +186,7 @@ class AdamWOpKernel : public AdamOpKernel<DeviceContext, T> {
return; return;
} }
float coeff = ctx.Attr<float>("coeff"); T coeff = static_cast<T>(ctx.Attr<float>("coeff"));
auto* lr = ctx.Input<LoDTensor>("LearningRate"); auto* lr = ctx.Input<LoDTensor>("LearningRate");
LoDTensor* param; LoDTensor* param;
...@@ -195,9 +198,7 @@ class AdamWOpKernel : public AdamOpKernel<DeviceContext, T> { ...@@ -195,9 +198,7 @@ class AdamWOpKernel : public AdamOpKernel<DeviceContext, T> {
param = const_cast<LoDTensor*>(ctx.Input<LoDTensor>("Param")); param = const_cast<LoDTensor*>(ctx.Input<LoDTensor>("Param"));
} }
// AdamWFunctor(float coeff, const float* learning_rate, T* parma) AdamWFunctor<T, CPUAdamW> functor(coeff, lr->data<T>(), param->data<T>());
AdamWFunctor<T, CPUAdamW> functor(coeff, *lr->data<float>(),
param->data<T>());
functor(param->numel()); functor(param->numel());
AdamOpKernel<DeviceContext, T>::Compute(ctx); AdamOpKernel<DeviceContext, T>::Compute(ctx);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册