From 87667c6636768fc19d6090019eb98070e2fa8ab1 Mon Sep 17 00:00:00 2001 From: wbn <66299196+wbn520@users.noreply.github.com> Date: Mon, 16 May 2022 14:35:30 +0800 Subject: [PATCH] Add the new XDNN implementation. test=kunlun (#42683) * Add the new XDNN implementation. test=kunlun * Add the new XDNN implementation. test=kunlun * Modify the code based on review, test=kunlun --- cmake/external/xpu.cmake | 4 +- paddle/fluid/operators/log_loss_op_xpu.cc | 83 +++----- .../fluid/operators/optimizers/lamb_op_xpu.cc | 157 ++++++-------- .../operators/optimizers/rmsprop_op_xpu.cc | 191 ++++++++---------- .../fluid/platform/device/xpu/xpu1_op_list.h | 5 + 5 files changed, 179 insertions(+), 261 deletions(-) diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake index 2c7f28b3a5..1c4a424995 100644 --- a/cmake/external/xpu.cmake +++ b/cmake/external/xpu.cmake @@ -9,7 +9,7 @@ SET(XPU_RT_LIB_NAME "libxpurt.so") if(NOT DEFINED XPU_BASE_URL) SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev") - SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220510") + SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220511") else() SET(XPU_BASE_URL "${XPU_BASE_URL}") endif() @@ -17,7 +17,7 @@ endif() # ubuntu and centos: use output by XDNN API team if(NOT DEFINED XPU_XDNN_BASE_URL) SET(XPU_XDNN_BASE_URL_WITHOUT_DATE "https://klx-sdk-release-public.su.bcebos.com/xdnn/dev") - SET(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220510") + SET(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220511") else() SET(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}") endif() diff --git a/paddle/fluid/operators/log_loss_op_xpu.cc b/paddle/fluid/operators/log_loss_op_xpu.cc index ead6f94417..fee1f56ebd 100644 --- a/paddle/fluid/operators/log_loss_op_xpu.cc +++ b/paddle/fluid/operators/log_loss_op_xpu.cc @@ -12,6 +12,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/device/device_wrapper.h" namespace paddle { namespace operators { @@ -21,67 +22,47 @@ template class LogLossXPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - /*** TODO wait XDNN new interface - auto* predict = ctx.Input("Predicted"); - auto* labels = ctx.Input("Labels"); - auto* loss = ctx.Output("Loss"); - auto epsilon = static_cast(ctx.Attr("epsilon")); - loss->mutable_data(ctx.GetPlace()); - int n = predict->numel(); - auto& dev_ctx = ctx.template device_context(); - int r = - xpu::log_loss_fwd(dev_ctx.x_context(), n, epsilon, - predict->data(), - labels->data(), loss->data()); - PADDLE_ENFORCE_EQ( - r, xpu::Error_t::SUCCESS, - platform::errors::External( - "XPU log_loss kernel return wrong value[%d], please check - whether " - "Baidu Kunlun Card is properly installed.", - r)); - ***/ + auto* predict = ctx.Input("Predicted"); + auto* labels = ctx.Input("Labels"); + auto* loss = ctx.Output("Loss"); + auto epsilon = static_cast(ctx.Attr("epsilon")); + loss->mutable_data(ctx.GetPlace()); + int n = predict->numel(); + auto& dev_ctx = ctx.template device_context(); + int r = xpu::log_loss(dev_ctx.x_context(), predict->data(), + labels->data(), loss->data(), n, epsilon); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "log_loss"); } }; template class LogLossGradXPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - /*** TODO wait XDNN new interface - - auto* predict = ctx.Input("Predicted"); - auto* labels = ctx.Input("Labels"); - auto* dloss = ctx.Input(framework::GradVarName("Loss")); - auto* dpred = ctx.Output(framework::GradVarName("Predicted")); - if (!dpred) { - return; - } - auto epsilon = static_cast(ctx.Attr("epsilon")); - dpred->mutable_data(ctx.GetPlace()); - int n = predict->numel(); - auto& dev_ctx = ctx.template device_context(); - int r = xpu::log_loss_bwd(dev_ctx.x_context(), n, epsilon, - predict->data(), labels->data(), - dloss->data(), dpred->data()); - PADDLE_ENFORCE_EQ( - r, xpu::Error_t::SUCCESS, - platform::errors::External( - "XPU log_loss kernel return wrong value[%d], please check - whether " - "Baidu Kunlun Card is properly installed.", - r)); - ***/ + auto* predict = ctx.Input("Predicted"); + auto* labels = ctx.Input("Labels"); + auto* dloss = ctx.Input(framework::GradVarName("Loss")); + auto* dpred = ctx.Output(framework::GradVarName("Predicted")); + if (!dpred) { + return; + } + auto epsilon = static_cast(ctx.Attr("epsilon")); + dpred->mutable_data(ctx.GetPlace()); + int n = predict->numel(); + auto& dev_ctx = ctx.template device_context(); + int r = xpu::log_loss_grad(dev_ctx.x_context(), predict->data(), + labels->data(), dloss->data(), + dpred->data(), n, epsilon); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "log_loss_grad"); } }; } // namespace operators } // namespace paddle -// namespace ops = paddle::operators; -// REGISTER_OP_XPU_KERNEL( -// log_loss, ops::LogLossXPUKernel); -// REGISTER_OP_XPU_KERNEL( -// log_loss_grad, -// ops::LogLossGradXPUKernel); +namespace ops = paddle::operators; +REGISTER_OP_XPU_KERNEL( + log_loss, ops::LogLossXPUKernel); +REGISTER_OP_XPU_KERNEL( + log_loss_grad, + ops::LogLossGradXPUKernel); #endif diff --git a/paddle/fluid/operators/optimizers/lamb_op_xpu.cc b/paddle/fluid/operators/optimizers/lamb_op_xpu.cc index 643f70b260..7aa5783a01 100644 --- a/paddle/fluid/operators/optimizers/lamb_op_xpu.cc +++ b/paddle/fluid/operators/optimizers/lamb_op_xpu.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/optimizers/lamb_op.h" #include "gflags/gflags.h" +#include "paddle/fluid/platform/device/device_wrapper.h" namespace paddle { namespace operators { @@ -25,111 +26,75 @@ template class LambOpXPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - /*** TODO wait XDNN new interface - using paddle::framework::LoDTensor; - const auto* param_var = ctx.InputVar("Param"); - PADDLE_ENFORCE_EQ(param_var->IsType(), true, - platform::errors::InvalidArgument( - "The Var(%s)'s type should be LoDTensor, " - "but the received is %s", - ctx.InputNames("Param").front(), - framework::ToTypeName(param_var->Type()))); + using paddle::framework::LoDTensor; + const auto* param_var = ctx.InputVar("Param"); + PADDLE_ENFORCE_EQ(param_var->IsType(), true, + platform::errors::InvalidArgument( + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.InputNames("Param").front(), + framework::ToTypeName(param_var->Type()))); - using paddle::framework::LoDTensor; + using paddle::framework::LoDTensor; - // inputs - T epsilon = static_cast(ctx.Attr("epsilon")); - T weight_decay = static_cast(ctx.Attr("weight_decay")); - T beta1 = static_cast(ctx.Attr("beta1")); - T beta2 = static_cast(ctx.Attr("beta2")); - auto& param = GET_DATA_SAFELY(ctx.Input("Param"), "Input", - "Param", "Lamb"); - auto* grad_var = ctx.InputVar("Grad"); - auto& mom1 = GET_DATA_SAFELY(ctx.Input("Moment1"), "Input", - "Moment1", "Lamb"); - auto& mom2 = GET_DATA_SAFELY(ctx.Input("Moment2"), "Input", - "Moment2", "Lamb"); - auto& lr = GET_DATA_SAFELY(ctx.Input("LearningRate"), - "Input", - "LearningRate", "Lamb"); + // inputs + T epsilon = static_cast(ctx.Attr("epsilon")); + T weight_decay = static_cast(ctx.Attr("weight_decay")); + T beta1 = static_cast(ctx.Attr("beta1")); + T beta2 = static_cast(ctx.Attr("beta2")); + auto& param = GET_DATA_SAFELY(ctx.Input("Param"), "Input", + "Param", "Lamb"); + auto* grad_var = ctx.InputVar("Grad"); + auto& mom1 = GET_DATA_SAFELY(ctx.Input("Moment1"), "Input", + "Moment1", "Lamb"); + auto& mom2 = GET_DATA_SAFELY(ctx.Input("Moment2"), "Input", + "Moment2", "Lamb"); + auto& lr = GET_DATA_SAFELY(ctx.Input("LearningRate"), "Input", + "LearningRate", "Lamb"); - auto& beta1_pow = GET_DATA_SAFELY(ctx.Input("Beta1Pow"), - "Input", - "Beta1Pow", "Lamb"); - auto& beta2_pow = GET_DATA_SAFELY(ctx.Input("Beta2Pow"), - "Input", - "Beta2Pow", "Lamb"); + auto& beta1_pow = GET_DATA_SAFELY(ctx.Input("Beta1Pow"), "Input", + "Beta1Pow", "Lamb"); + auto& beta2_pow = GET_DATA_SAFELY(ctx.Input("Beta2Pow"), "Input", + "Beta2Pow", "Lamb"); - auto& param_out = GET_DATA_SAFELY(ctx.Output("ParamOut"), - "Output", "ParamOut", "Lamb"); - auto& mom1_out = GET_DATA_SAFELY(ctx.Output("Moment1Out"), - "Output", "Moment1Out", "Lamb"); - auto& mom2_out = GET_DATA_SAFELY(ctx.Output("Moment2Out"), - "Output", "Moment2Out", "Lamb"); - auto& beta1_pow_out = - GET_DATA_SAFELY(ctx.Output("Beta1PowOut"), - "Output", "Beta1PowOut", "Lamb"); - auto& beta2_pow_out = - GET_DATA_SAFELY(ctx.Output("Beta2PowOut"), - "Output", "Beta2PowOut", "Lamb"); - auto& dev_ctx = ctx.template device_context(); + auto& param_out = GET_DATA_SAFELY(ctx.Output("ParamOut"), + "Output", "ParamOut", "Lamb"); + auto& mom1_out = GET_DATA_SAFELY(ctx.Output("Moment1Out"), + "Output", "Moment1Out", "Lamb"); + auto& mom2_out = GET_DATA_SAFELY(ctx.Output("Moment2Out"), + "Output", "Moment2Out", "Lamb"); + auto& beta1_pow_out = GET_DATA_SAFELY(ctx.Output("Beta1PowOut"), + "Output", "Beta1PowOut", "Lamb"); + auto& beta2_pow_out = GET_DATA_SAFELY(ctx.Output("Beta2PowOut"), + "Output", "Beta2PowOut", "Lamb"); + auto& dev_ctx = ctx.template device_context(); - if (grad_var->IsType()) { - auto& grad = *ctx.Input("Grad"); - int r = xpu::lamb(dev_ctx.x_context(), grad.template data(), - mom1.template data(), mom2.template data(), - param.template data(), beta1_pow.template - data(), - beta2_pow.template data(), beta1, beta2, epsilon, - weight_decay, lr.template data(), - mom1_out.template mutable_data(ctx.GetPlace()), - mom2_out.template mutable_data(ctx.GetPlace()), - param_out.template mutable_data(ctx.GetPlace()), - beta1_pow_out.template - mutable_data(ctx.GetPlace()), - beta2_pow_out.template - mutable_data(ctx.GetPlace()), - param.numel()); + if (grad_var->IsType()) { + auto& grad = *ctx.Input("Grad"); + int r = xpu::lamb( + dev_ctx.x_context(), grad.template data(), mom1.template data(), + mom2.template data(), param.template data(), + beta1_pow.template data(), beta2_pow.template data(), + mom1_out.template mutable_data(ctx.GetPlace()), + mom2_out.template mutable_data(ctx.GetPlace()), + param_out.template mutable_data(ctx.GetPlace()), + beta1_pow_out.template mutable_data(ctx.GetPlace()), + beta2_pow_out.template mutable_data(ctx.GetPlace()), beta1, beta2, + epsilon, weight_decay, lr.template data(), param.numel()); - if (r == xpu::Error_t::INVALID_PARAM) { - PADDLE_ENFORCE_EQ( - r, xpu::Error_t::SUCCESS, - platform::errors::InvalidArgument( - "XPU kernel error of LambOp, error message: INVALID_PARAM, " - "please check your input & output.")); - } else if (r == xpu::Error_t::RUNTIME_ERROR) { - PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, - platform::errors::Unavailable( - "XPU kernel error of LambOp, error message: " - "RUNTIME_ERROR, please check whether Baidu " - "Kunlun Card is properly installed.")); - } else if (r == xpu::Error_t::NO_ENOUGH_WORKSPACE) { - PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, - platform::errors::ResourceExhausted( - "XPU kernel error of LambOp, error " - "message: NO_ENOUGH_WORKSPACE, XPU " - "has no enough memory.")); - } else { - PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, - platform::errors::ResourceExhausted( - "XPU kernel error of LambOp, error " - "message: OTHER " - "XPU API returns error code: %d.", - r)); - } - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "Variable type not supported by lamb_op. Expect LoDTensor, " - "but got %s", - framework::ToTypeName(param_var->Type()))); - } - **/ + PADDLE_ENFORCE_XDNN_SUCCESS(r, "lamb"); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Variable type not supported by lamb_op. Expect LoDTensor, " + "but got %s", + framework::ToTypeName(param_var->Type()))); + } } }; } // namespace operators } // namespace paddle -// namespace ops = paddle::operators; -// REGISTER_OP_XPU_KERNEL( -// lamb, ops::LambOpXPUKernel); +namespace ops = paddle::operators; +REGISTER_OP_XPU_KERNEL( + lamb, ops::LambOpXPUKernel); #endif diff --git a/paddle/fluid/operators/optimizers/rmsprop_op_xpu.cc b/paddle/fluid/operators/optimizers/rmsprop_op_xpu.cc index 873056c7f6..b53d51686c 100644 --- a/paddle/fluid/operators/optimizers/rmsprop_op_xpu.cc +++ b/paddle/fluid/operators/optimizers/rmsprop_op_xpu.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/device/device_wrapper.h" namespace paddle { namespace operators { @@ -40,122 +41,88 @@ template class RmspropOpXPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - /*** TODO wait XDNN new interface - using paddle::framework::LoDTensor; - - // check Param & Grad tensor type - const auto* param_var = ctx.InputVar("Param"); - PADDLE_ENFORCE_EQ(param_var->IsType(), true, - platform::errors::InvalidArgument( - "Tensor holds the wrong type,Expected Var(%s)'s " - "type is LoDTensor, " - "but the received is %s", - ctx.InputNames("Param").front(), - framework::ToTypeName(param_var->Type()))); - - const auto* grad_var = ctx.InputVar("Grad"); - PADDLE_ENFORCE_EQ(grad_var->IsType(), true, - platform::errors::InvalidArgument( - "Tensor holds the wrong type,Expected Var(%s)'s " - "type is LoDTensor, " - "but the received is %s", - ctx.InputNames("Grad").front(), - framework::ToTypeName(grad_var->Type()))); - - // inputs - auto& param = GET_DATA_SAFELY(ctx.Input("Param"), "Input", - "Param", "Rmsprop"); - auto& meanSquare = GET_DATA_SAFELY(ctx.Input("MeanSquare"), - "Input", "MeanSquare", "Rmsprop"); - auto& grad = GET_DATA_SAFELY(ctx.Input("Grad"), "Input", - "Grad", - "Rmsprop"); - auto& mom = GET_DATA_SAFELY(ctx.Input("Moment"), "Input", - "Moment", "Rmsprop"); - - auto* learning_rate = ctx.Input("LearningRate"); - PADDLE_ENFORCE_EQ(learning_rate->dims().size(), 1, - platform::errors::InvalidArgument( - "learining rate should have dimension = 1." - " But received learning rate dim [%s] ", - learning_rate->dims().size())); - T lr = static_cast(GetAttrFromTensor(learning_rate)); - - // constants - T epsilon = static_cast(ctx.Attr("epsilon")); - T decay = static_cast(ctx.Attr("decay")); - T momentum = static_cast(ctx.Attr("momentum")); - - // outputs - auto& param_out = GET_DATA_SAFELY(ctx.Output("ParamOut"), - "Output", "ParamOut", "Rmsprop"); - auto& mom_out = GET_DATA_SAFELY(ctx.Output("MomentOut"), - "Output", "MomentOut", "Rmsprop"); - auto& mom_sqrt_out = - GET_DATA_SAFELY(ctx.Output("MeanSquareOut"), - "Output", "MeanSquareOut", - "Rmsprop"); - auto& dev_ctx = ctx.template device_context(); - - ///// rmsprop优化算法 - /// - /// ms_out[i] = rho * ms[i] + (1 - rho) * (g[i] * g[i]); - /// - /// mom_out[i] = momentum * mom[i] + lr * - /// (g[i] / ((float)sqrt(ms_out[i] + epsilon))); - /// - /// p_out[i] = p[i] - mom_out[i]; - /// DLL_EXPORT int rmsprop(Context* ctx, const float* p, - /// const float* ms, const float* g, const float* mom, - /// float epsilon, float rho, float momentum, float lr, - /// float *ms_out, float *mom_out, float *p_out, int n) - int r = xpu::rmsprop(dev_ctx.x_context(), param.template data(), - meanSquare.template data(), grad.template - data(), - mom.template data(), epsilon, decay, momentum, - lr, - mom_sqrt_out.template - mutable_data(ctx.GetPlace()), - mom_out.template mutable_data(ctx.GetPlace()), - param_out.template mutable_data(ctx.GetPlace()), - param.numel()); - - if (r == xpu::Error_t::INVALID_PARAM) { - PADDLE_ENFORCE_EQ( - r, xpu::Error_t::SUCCESS, - platform::errors::InvalidArgument( - "XPU kernel error of RmspropOp, error message: INVALID_PARAM, - " - "please check your input & output.")); - } else if (r == xpu::Error_t::RUNTIME_ERROR) { - PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, - platform::errors::Unavailable( - "XPU kernel error of RmspropOp, error message: " - "RUNTIME_ERROR, please check whether Baidu " - "Kunlun Card is properly installed.")); - } else if (r == xpu::Error_t::NO_ENOUGH_WORKSPACE) { - PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, - platform::errors::ResourceExhausted( - "XPU kernel error of RmspropOp, error " - "message: NO_ENOUGH_WORKSPACE, XPU " - "has no enough memory.")); - } else { - PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, - platform::errors::ResourceExhausted( - "XPU kernel error of RmspropOp, error " - "message: OTHER " - "XPU API returns error code: %d.", - r)); - } - ***/ + using paddle::framework::LoDTensor; + + // check Param & Grad tensor type + const auto* param_var = ctx.InputVar("Param"); + PADDLE_ENFORCE_EQ(param_var->IsType(), true, + platform::errors::InvalidArgument( + "Tensor holds the wrong type,Expected Var(%s)'s " + "type is LoDTensor, " + "but the received is %s", + ctx.InputNames("Param").front(), + framework::ToTypeName(param_var->Type()))); + + const auto* grad_var = ctx.InputVar("Grad"); + PADDLE_ENFORCE_EQ(grad_var->IsType(), true, + platform::errors::InvalidArgument( + "Tensor holds the wrong type,Expected Var(%s)'s " + "type is LoDTensor, " + "but the received is %s", + ctx.InputNames("Grad").front(), + framework::ToTypeName(grad_var->Type()))); + + // inputs + auto& param = GET_DATA_SAFELY(ctx.Input("Param"), "Input", + "Param", "Rmsprop"); + auto& meanSquare = GET_DATA_SAFELY(ctx.Input("MeanSquare"), + "Input", "MeanSquare", "Rmsprop"); + auto& grad = GET_DATA_SAFELY(ctx.Input("Grad"), "Input", "Grad", + "Rmsprop"); + auto& mom = GET_DATA_SAFELY(ctx.Input("Moment"), "Input", + "Moment", "Rmsprop"); + + auto* learning_rate = ctx.Input("LearningRate"); + PADDLE_ENFORCE_EQ(learning_rate->dims().size(), 1, + platform::errors::InvalidArgument( + "learining rate should have dimension = 1." + " But received learning rate dim [%s] ", + learning_rate->dims().size())); + T lr = static_cast(GetAttrFromTensor(learning_rate)); + + // constants + T epsilon = static_cast(ctx.Attr("epsilon")); + T decay = static_cast(ctx.Attr("decay")); + T momentum = static_cast(ctx.Attr("momentum")); + + // outputs + auto& param_out = GET_DATA_SAFELY(ctx.Output("ParamOut"), + "Output", "ParamOut", "Rmsprop"); + auto& mom_out = GET_DATA_SAFELY(ctx.Output("MomentOut"), + "Output", "MomentOut", "Rmsprop"); + auto& mom_sqrt_out = GET_DATA_SAFELY(ctx.Output("MeanSquareOut"), + "Output", "MeanSquareOut", "Rmsprop"); + auto& dev_ctx = ctx.template device_context(); + + ///// rmsprop优化算法 + /// + /// ms_out[i] = rho * ms[i] + (1 - rho) * (g[i] * g[i]); + /// + /// mom_out[i] = momentum * mom[i] + lr * + /// (g[i] / ((float)sqrt(ms_out[i] + epsilon))); + /// + /// p_out[i] = p[i] - mom_out[i]; + /// DLL_EXPORT int rmsprop(Context* ctx, const float* p, + /// const float* ms, const float* g, const float* mom, + /// float epsilon, float rho, float momentum, float lr, + /// float *ms_out, float *mom_out, float *p_out, int n) + int r = xpu::rmsprop(dev_ctx.x_context(), grad.template data(), + param.template data(), + meanSquare.template data(), mom.template data(), + param_out.template mutable_data(ctx.GetPlace()), + mom_sqrt_out.template mutable_data(ctx.GetPlace()), + mom_out.template mutable_data(ctx.GetPlace()), + epsilon, decay, momentum, lr, param.numel()); + + PADDLE_ENFORCE_XDNN_SUCCESS(r, "rmsprop"); } }; } // namespace operators } // namespace paddle -// namespace ops = paddle::operators; -// REGISTER_OP_XPU_KERNEL( -// rmsprop, -// ops::RmspropOpXPUKernel); +namespace ops = paddle::operators; +REGISTER_OP_XPU_KERNEL( + rmsprop, + ops::RmspropOpXPUKernel); #endif diff --git a/paddle/fluid/platform/device/xpu/xpu1_op_list.h b/paddle/fluid/platform/device/xpu/xpu1_op_list.h index e8c3eee5b5..a76bdd4ae9 100644 --- a/paddle/fluid/platform/device/xpu/xpu1_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu1_op_list.h @@ -145,6 +145,7 @@ XPUOpMap& get_kl1_ops() { {"hard_switch", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"iou_similarity", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"lamb", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"layer_norm_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"layer_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, @@ -174,6 +175,9 @@ XPUOpMap& get_kl1_ops() { pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, + {"log_loss_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"log_loss", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"logsumexp", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"log", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"lookup_table_v2_grad", @@ -232,6 +236,7 @@ XPUOpMap& get_kl1_ops() { pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::BOOL, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, + {"rmsprop", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"rnn_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"rnn", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"roi_align_grad", -- GitLab