Unverified commit 87667c66, authored by wbn, committed by GitHub

Add the new XDNN implementation. test=kunlun (#42683)

* Add the new XDNN implementation. test=kunlun

* Add the new XDNN implementation. test=kunlun

* Modify the code based on review, test=kunlun
Parent 34cda80b
@@ -9,7 +9,7 @@ SET(XPU_RT_LIB_NAME "libxpurt.so")
 if(NOT DEFINED XPU_BASE_URL)
   SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
-  SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220510")
+  SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220511")
 else()
   SET(XPU_BASE_URL "${XPU_BASE_URL}")
 endif()
@@ -17,7 +17,7 @@ endif()
 # ubuntu and centos: use output by XDNN API team
 if(NOT DEFINED XPU_XDNN_BASE_URL)
   SET(XPU_XDNN_BASE_URL_WITHOUT_DATE "https://klx-sdk-release-public.su.bcebos.com/xdnn/dev")
-  SET(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220510")
+  SET(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220511")
 else()
   SET(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}")
 endif()
@@ -12,6 +12,7 @@ limitations under the License. */
 #include <memory>
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/device/device_wrapper.h"

 namespace paddle {
 namespace operators {
@@ -21,67 +22,47 @@ template <typename DeviceContext, typename T, typename AttrType = T>
 class LogLossXPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    /*** TODO wait XDNN new interface
-    auto* predict = ctx.Input<Tensor>("Predicted");
-    auto* labels = ctx.Input<Tensor>("Labels");
-    auto* loss = ctx.Output<Tensor>("Loss");
-    auto epsilon = static_cast<T>(ctx.Attr<AttrType>("epsilon"));
-    loss->mutable_data<T>(ctx.GetPlace());
-    int n = predict->numel();
-    auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    int r =
-        xpu::log_loss_fwd(dev_ctx.x_context(), n, epsilon, predict->data<T>(),
-                          labels->data<T>(), loss->data<T>());
-    PADDLE_ENFORCE_EQ(
-        r, xpu::Error_t::SUCCESS,
-        platform::errors::External(
-            "XPU log_loss kernel return wrong value[%d], please check whether "
-            "Baidu Kunlun Card is properly installed.",
-            r));
-    ***/
+    auto* predict = ctx.Input<Tensor>("Predicted");
+    auto* labels = ctx.Input<Tensor>("Labels");
+    auto* loss = ctx.Output<Tensor>("Loss");
+    auto epsilon = static_cast<T>(ctx.Attr<AttrType>("epsilon"));
+    loss->mutable_data<T>(ctx.GetPlace());
+    int n = predict->numel();
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    int r = xpu::log_loss(dev_ctx.x_context(), predict->data<T>(),
+                          labels->data<T>(), loss->data<T>(), n, epsilon);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "log_loss");
   }
 };

 template <typename DeviceContext, typename T, typename AttrType = T>
 class LogLossGradXPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    /*** TODO wait XDNN new interface
-    auto* predict = ctx.Input<Tensor>("Predicted");
-    auto* labels = ctx.Input<Tensor>("Labels");
-    auto* dloss = ctx.Input<Tensor>(framework::GradVarName("Loss"));
-    auto* dpred = ctx.Output<Tensor>(framework::GradVarName("Predicted"));
-    if (!dpred) {
-      return;
-    }
-    auto epsilon = static_cast<T>(ctx.Attr<AttrType>("epsilon"));
-    dpred->mutable_data<T>(ctx.GetPlace());
-    int n = predict->numel();
-    auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    int r = xpu::log_loss_bwd(dev_ctx.x_context(), n, epsilon,
-                              predict->data<T>(), labels->data<T>(),
-                              dloss->data<T>(), dpred->data<T>());
-    PADDLE_ENFORCE_EQ(
-        r, xpu::Error_t::SUCCESS,
-        platform::errors::External(
-            "XPU log_loss kernel return wrong value[%d], please check whether "
-            "Baidu Kunlun Card is properly installed.",
-            r));
-    ***/
+    auto* predict = ctx.Input<Tensor>("Predicted");
+    auto* labels = ctx.Input<Tensor>("Labels");
+    auto* dloss = ctx.Input<Tensor>(framework::GradVarName("Loss"));
+    auto* dpred = ctx.Output<Tensor>(framework::GradVarName("Predicted"));
+    if (!dpred) {
+      return;
+    }
+    auto epsilon = static_cast<T>(ctx.Attr<AttrType>("epsilon"));
+    dpred->mutable_data<T>(ctx.GetPlace());
+    int n = predict->numel();
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    int r = xpu::log_loss_grad(dev_ctx.x_context(), predict->data<T>(),
+                               labels->data<T>(), dloss->data<T>(),
+                               dpred->data<T>(), n, epsilon);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "log_loss_grad");
   }
 };

 }  // namespace operators
 }  // namespace paddle

-// namespace ops = paddle::operators;
-// REGISTER_OP_XPU_KERNEL(
-//     log_loss, ops::LogLossXPUKernel<paddle::platform::XPUDeviceContext,
-//     float>);
-// REGISTER_OP_XPU_KERNEL(
-//     log_loss_grad,
-//     ops::LogLossGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
+namespace ops = paddle::operators;
+REGISTER_OP_XPU_KERNEL(
+    log_loss, ops::LogLossXPUKernel<paddle::platform::XPUDeviceContext, float>);
+REGISTER_OP_XPU_KERNEL(
+    log_loss_grad,
+    ops::LogLossGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
 #endif
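For readers cross-checking the new argument order above, the element-wise math that the log_loss and log_loss_grad calls are expected to compute (following Paddle's CPU definition of the log_loss op) is roughly the sketch below; log_loss_ref and log_loss_grad_ref are illustrative names only, not XDNN symbols.

#include <cmath>
#include <cstddef>

// Illustrative host-side reference; not the XDNN implementation.
// loss[i] = -label[i] * log(pred[i] + eps)
//           - (1 - label[i]) * log(1 - pred[i] + eps)
void log_loss_ref(const float* pred, const float* label, float* loss,
                  std::size_t n, float eps) {
  for (std::size_t i = 0; i < n; ++i) {
    loss[i] = -label[i] * std::log(pred[i] + eps) -
              (1.0f - label[i]) * std::log(1.0f - pred[i] + eps);
  }
}

// dpred[i] = dloss[i] * (-label[i] / (pred[i] + eps)
//                        + (1 - label[i]) / (1 - pred[i] + eps))
void log_loss_grad_ref(const float* pred, const float* label,
                       const float* dloss, float* dpred, std::size_t n,
                       float eps) {
  for (std::size_t i = 0; i < n; ++i) {
    dpred[i] = dloss[i] * (-label[i] / (pred[i] + eps) +
                           (1.0f - label[i]) / (1.0f - pred[i] + eps));
  }
}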
@@ -14,6 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/optimizers/lamb_op.h"
 #include "gflags/gflags.h"
+#include "paddle/fluid/platform/device/device_wrapper.h"

 namespace paddle {
 namespace operators {
@@ -25,111 +26,75 @@ template <typename DeviceContext, typename T>
 class LambOpXPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    /*** TODO wait XDNN new interface
     using paddle::framework::LoDTensor;
     const auto* param_var = ctx.InputVar("Param");
     PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
                       platform::errors::InvalidArgument(
                           "The Var(%s)'s type should be LoDTensor, "
                           "but the received is %s",
                           ctx.InputNames("Param").front(),
                           framework::ToTypeName(param_var->Type())));

     using paddle::framework::LoDTensor;
     // inputs
     T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
     T weight_decay = static_cast<T>(ctx.Attr<float>("weight_decay"));
     T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
     T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
     auto& param = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Param"), "Input",
                                   "Param", "Lamb");
     auto* grad_var = ctx.InputVar("Grad");
     auto& mom1 = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Moment1"), "Input",
                                  "Moment1", "Lamb");
     auto& mom2 = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Moment2"), "Input",
                                  "Moment2", "Lamb");
     auto& lr = GET_DATA_SAFELY(ctx.Input<LoDTensor>("LearningRate"), "Input",
                                "LearningRate", "Lamb");

     auto& beta1_pow = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Beta1Pow"), "Input",
                                       "Beta1Pow", "Lamb");
     auto& beta2_pow = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Beta2Pow"), "Input",
                                       "Beta2Pow", "Lamb");

     auto& param_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("ParamOut"),
                                       "Output", "ParamOut", "Lamb");
     auto& mom1_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("Moment1Out"),
                                      "Output", "Moment1Out", "Lamb");
     auto& mom2_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("Moment2Out"),
                                      "Output", "Moment2Out", "Lamb");
     auto& beta1_pow_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("Beta1PowOut"),
                                           "Output", "Beta1PowOut", "Lamb");
     auto& beta2_pow_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("Beta2PowOut"),
                                           "Output", "Beta2PowOut", "Lamb");
     auto& dev_ctx = ctx.template device_context<DeviceContext>();

     if (grad_var->IsType<framework::LoDTensor>()) {
       auto& grad = *ctx.Input<LoDTensor>("Grad");
-      int r = xpu::lamb(dev_ctx.x_context(), grad.template data<T>(),
-                        mom1.template data<T>(), mom2.template data<T>(),
-                        param.template data<T>(), beta1_pow.template data<T>(),
-                        beta2_pow.template data<T>(), beta1, beta2, epsilon,
-                        weight_decay, lr.template data<T>(),
-                        mom1_out.template mutable_data<T>(ctx.GetPlace()),
-                        mom2_out.template mutable_data<T>(ctx.GetPlace()),
-                        param_out.template mutable_data<T>(ctx.GetPlace()),
-                        beta1_pow_out.template mutable_data<T>(ctx.GetPlace()),
-                        beta2_pow_out.template mutable_data<T>(ctx.GetPlace()),
-                        param.numel());
+      int r = xpu::lamb(
+          dev_ctx.x_context(), grad.template data<T>(), mom1.template data<T>(),
+          mom2.template data<T>(), param.template data<T>(),
+          beta1_pow.template data<T>(), beta2_pow.template data<T>(),
+          mom1_out.template mutable_data<T>(ctx.GetPlace()),
+          mom2_out.template mutable_data<T>(ctx.GetPlace()),
+          param_out.template mutable_data<T>(ctx.GetPlace()),
+          beta1_pow_out.template mutable_data<T>(ctx.GetPlace()),
+          beta2_pow_out.template mutable_data<T>(ctx.GetPlace()), beta1, beta2,
+          epsilon, weight_decay, lr.template data<T>(), param.numel());

-      if (r == xpu::Error_t::INVALID_PARAM) {
-        PADDLE_ENFORCE_EQ(
-            r, xpu::Error_t::SUCCESS,
-            platform::errors::InvalidArgument(
-                "XPU kernel error of LambOp, error message: INVALID_PARAM, "
-                "please check your input & output."));
-      } else if (r == xpu::Error_t::RUNTIME_ERROR) {
-        PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
-                          platform::errors::Unavailable(
-                              "XPU kernel error of LambOp, error message: "
-                              "RUNTIME_ERROR, please check whether Baidu "
-                              "Kunlun Card is properly installed."));
-      } else if (r == xpu::Error_t::NO_ENOUGH_WORKSPACE) {
-        PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
-                          platform::errors::ResourceExhausted(
-                              "XPU kernel error of LambOp, error "
-                              "message: NO_ENOUGH_WORKSPACE, XPU "
-                              "has no enough memory."));
-      } else {
-        PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
-                          platform::errors::ResourceExhausted(
-                              "XPU kernel error of LambOp, error "
-                              "message: OTHER "
-                              "XPU API returns error code: %d.",
-                              r));
-      }
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "lamb");
     } else {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "Variable type not supported by lamb_op. Expect LoDTensor, "
           "but got %s",
           framework::ToTypeName(param_var->Type())));
     }
-    **/
   }
 };

 }  // namespace operators
 }  // namespace paddle

-// namespace ops = paddle::operators;
-// REGISTER_OP_XPU_KERNEL(
-//     lamb, ops::LambOpXPUKernel<paddle::platform::XPUDeviceContext, float>);
+namespace ops = paddle::operators;
+REGISTER_OP_XPU_KERNEL(
+    lamb, ops::LambOpXPUKernel<paddle::platform::XPUDeviceContext, float>);
 #endif
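As a rough sanity check on the argument list of the fused xpu::lamb call above, the sketch below shows a host-side single-tensor LAMB step (Adam-style moment updates, bias correction, decoupled weight decay, and a layer-wise trust ratio). The name lamb_reference and the exact bias-correction and trust-ratio details are assumptions for illustration, not the XDNN kernel contract; the real kernel also advances Beta1PowOut/Beta2PowOut, which is omitted here.

#include <cmath>
#include <cstddef>

// Illustrative single-tensor LAMB step; assumes p_out does not alias p.
void lamb_reference(const float* g, const float* m, const float* v,
                    const float* p, const float* beta1_pow,
                    const float* beta2_pow, float* m_out, float* v_out,
                    float* p_out, float beta1, float beta2, float epsilon,
                    float weight_decay, const float* lr, std::size_t n) {
  double p_norm = 0.0, r_norm = 0.0;
  // Pass 1: moment updates, bias correction, weight decay.
  for (std::size_t i = 0; i < n; ++i) {
    m_out[i] = beta1 * m[i] + (1.0f - beta1) * g[i];
    v_out[i] = beta2 * v[i] + (1.0f - beta2) * g[i] * g[i];
    float m_hat = m_out[i] / (1.0f - beta1_pow[0]);
    float v_hat = v_out[i] / (1.0f - beta2_pow[0]);
    float r = m_hat / (std::sqrt(v_hat) + epsilon) + weight_decay * p[i];
    p_out[i] = r;  // stash the raw update direction for pass 2
    p_norm += static_cast<double>(p[i]) * p[i];
    r_norm += static_cast<double>(r) * r;
  }
  // Layer-wise trust ratio: scale the update by ||p|| / ||r||.
  float trust = (p_norm > 0.0 && r_norm > 0.0)
                    ? static_cast<float>(std::sqrt(p_norm / r_norm))
                    : 1.0f;
  // Pass 2: apply the scaled update.
  for (std::size_t i = 0; i < n; ++i) {
    p_out[i] = p[i] - lr[0] * trust * p_out[i];
  }
}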
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <gflags/gflags.h>
 #include <iostream>
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/device/device_wrapper.h"

 namespace paddle {
 namespace operators {
@@ -40,122 +41,88 @@ template <typename DeviceContext, typename T>
 class RmspropOpXPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    /*** TODO wait XDNN new interface
     using paddle::framework::LoDTensor;

     // check Param & Grad tensor type
     const auto* param_var = ctx.InputVar("Param");
     PADDLE_ENFORCE_EQ(param_var->IsType<LoDTensor>(), true,
                       platform::errors::InvalidArgument(
                           "Tensor holds the wrong type,Expected Var(%s)'s "
                           "type is LoDTensor, "
                           "but the received is %s",
                           ctx.InputNames("Param").front(),
                           framework::ToTypeName(param_var->Type())));

     const auto* grad_var = ctx.InputVar("Grad");
     PADDLE_ENFORCE_EQ(grad_var->IsType<LoDTensor>(), true,
                       platform::errors::InvalidArgument(
                           "Tensor holds the wrong type,Expected Var(%s)'s "
                           "type is LoDTensor, "
                           "but the received is %s",
                           ctx.InputNames("Grad").front(),
                           framework::ToTypeName(grad_var->Type())));

     // inputs
     auto& param = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Param"), "Input",
                                   "Param", "Rmsprop");
     auto& meanSquare = GET_DATA_SAFELY(ctx.Input<LoDTensor>("MeanSquare"),
                                        "Input", "MeanSquare", "Rmsprop");
     auto& grad = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Grad"), "Input", "Grad",
                                  "Rmsprop");
     auto& mom = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Moment"), "Input",
                                 "Moment", "Rmsprop");

     auto* learning_rate = ctx.Input<Tensor>("LearningRate");
     PADDLE_ENFORCE_EQ(learning_rate->dims().size(), 1,
                       platform::errors::InvalidArgument(
                           "learining rate should have dimension = 1."
                           " But received learning rate dim [%s] ",
                           learning_rate->dims().size()));
     T lr = static_cast<T>(GetAttrFromTensor(learning_rate));

     // constants
     T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
     T decay = static_cast<T>(ctx.Attr<float>("decay"));
     T momentum = static_cast<T>(ctx.Attr<float>("momentum"));

     // outputs
     auto& param_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("ParamOut"),
                                       "Output", "ParamOut", "Rmsprop");
     auto& mom_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("MomentOut"),
                                     "Output", "MomentOut", "Rmsprop");
     auto& mom_sqrt_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("MeanSquareOut"),
                                          "Output", "MeanSquareOut", "Rmsprop");
     auto& dev_ctx = ctx.template device_context<DeviceContext>();

     ///// rmsprop optimization algorithm
     ///
     /// ms_out[i] = rho * ms[i] + (1 - rho) * (g[i] * g[i]);
     ///
     /// mom_out[i] = momentum * mom[i] + lr *
     ///              (g[i] / ((float)sqrt(ms_out[i] + epsilon)));
     ///
     /// p_out[i] = p[i] - mom_out[i];
     /// DLL_EXPORT int rmsprop(Context* ctx, const float* p,
     ///                        const float* ms, const float* g, const float* mom,
     ///                        float epsilon, float rho, float momentum, float lr,
     ///                        float *ms_out, float *mom_out, float *p_out, int n)
-    int r = xpu::rmsprop(dev_ctx.x_context(), param.template data<T>(),
-                         meanSquare.template data<T>(), grad.template data<T>(),
-                         mom.template data<T>(), epsilon, decay, momentum, lr,
-                         mom_sqrt_out.template mutable_data<T>(ctx.GetPlace()),
-                         mom_out.template mutable_data<T>(ctx.GetPlace()),
-                         param_out.template mutable_data<T>(ctx.GetPlace()),
-                         param.numel());
+    int r = xpu::rmsprop(dev_ctx.x_context(), grad.template data<T>(),
+                         param.template data<T>(),
+                         meanSquare.template data<T>(), mom.template data<T>(),
+                         param_out.template mutable_data<T>(ctx.GetPlace()),
+                         mom_sqrt_out.template mutable_data<T>(ctx.GetPlace()),
+                         mom_out.template mutable_data<T>(ctx.GetPlace()),
+                         epsilon, decay, momentum, lr, param.numel());

-    if (r == xpu::Error_t::INVALID_PARAM) {
-      PADDLE_ENFORCE_EQ(
-          r, xpu::Error_t::SUCCESS,
-          platform::errors::InvalidArgument(
-              "XPU kernel error of RmspropOp, error message: INVALID_PARAM, "
-              "please check your input & output."));
-    } else if (r == xpu::Error_t::RUNTIME_ERROR) {
-      PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
-                        platform::errors::Unavailable(
-                            "XPU kernel error of RmspropOp, error message: "
-                            "RUNTIME_ERROR, please check whether Baidu "
-                            "Kunlun Card is properly installed."));
-    } else if (r == xpu::Error_t::NO_ENOUGH_WORKSPACE) {
-      PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
-                        platform::errors::ResourceExhausted(
-                            "XPU kernel error of RmspropOp, error "
-                            "message: NO_ENOUGH_WORKSPACE, XPU "
-                            "has no enough memory."));
-    } else {
-      PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
-                        platform::errors::ResourceExhausted(
-                            "XPU kernel error of RmspropOp, error "
-                            "message: OTHER "
-                            "XPU API returns error code: %d.",
-                            r));
-    }
-    ***/
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "rmsprop");
   }
 };

 }  // namespace operators
 }  // namespace paddle

-// namespace ops = paddle::operators;
-// REGISTER_OP_XPU_KERNEL(
-//     rmsprop,
-//     ops::RmspropOpXPUKernel<paddle::platform::XPUDeviceContext, float>);
+namespace ops = paddle::operators;
+REGISTER_OP_XPU_KERNEL(
+    rmsprop,
+    ops::RmspropOpXPUKernel<paddle::platform::XPUDeviceContext, float>);
 #endif
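The /// comments in the kernel above already give the element-wise update; as a quick cross-check of the new call's argument order (grad, param, meanSquare, mom, param_out, ms_out, mom_out, epsilon, decay, momentum, lr, n), a host-side reference of that same rule might look like the sketch below. rmsprop_ref is an illustrative name, not an XDNN symbol, and the decay attribute plays the role of rho.

#include <cmath>
#include <cstddef>

// Illustrative reference implementing the update documented in the kernel.
void rmsprop_ref(const float* g, const float* p, const float* ms,
                 const float* mom, float* p_out, float* ms_out, float* mom_out,
                 float epsilon, float rho, float momentum, float lr,
                 std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    ms_out[i] = rho * ms[i] + (1.0f - rho) * g[i] * g[i];
    mom_out[i] =
        momentum * mom[i] + lr * (g[i] / std::sqrt(ms_out[i] + epsilon));
    p_out[i] = p[i] - mom_out[i];
  }
}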
@@ -145,6 +145,7 @@ XPUOpMap& get_kl1_ops() {
     {"hard_switch", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"iou_similarity",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"lamb", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"layer_norm_grad",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"layer_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
@@ -174,6 +175,9 @@ XPUOpMap& get_kl1_ops() {
      pOpKernelType(vartype::INT32, XPUPlace()),
      pOpKernelType(vartype::INT64, XPUPlace()),
      pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"log_loss_grad",
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"log_loss", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"logsumexp", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"log", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"lookup_table_v2_grad",
@@ -232,6 +236,7 @@ XPUOpMap& get_kl1_ops() {
      pOpKernelType(vartype::INT32, XPUPlace()),
      pOpKernelType(vartype::BOOL, XPUPlace()),
      pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"rmsprop", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"rnn_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"rnn", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"roi_align_grad",