Unverified commit 86a6be1a authored by taixiurong, committed by GitHub

add xpu_wait & new implementation to replace memcpy in adam, adamw (#35437)

Parent 1a7b3ff6
...
@@ -35,7 +35,7 @@ ELSE ()
 ENDIF()
 SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
-SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210830")
+SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210909")
 SET(XPU_XRE_URL "${XPU_BASE_URL}/${XPU_XRE_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
 SET(XPU_XDNN_URL "${XPU_BASE_URL}/${XPU_XDNN_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
 SET(XPU_XCCL_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210623/${XPU_XCCL_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
...
...
@@ -24,6 +24,8 @@ using DDim = framework::DDim;
 template <typename DeviceContext, typename T>
 class LayerNormXPUKernel : public framework::OpKernel<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
@@ -39,15 +41,17 @@ class LayerNormXPUKernel : public framework::OpKernel<T> {
     auto* mean = ctx.Output<Tensor>("Mean");
     auto* variance = ctx.Output<Tensor>("Variance");
     const auto* x_data = x->data<T>();
-    const auto* scale_data = (scale == nullptr ? nullptr : scale->data<T>());
-    const auto* bias_data = (bias == nullptr ? nullptr : bias->data<T>());
+    const auto* scale_data =
+        (scale == nullptr ? nullptr : scale->data<float>());
+    const auto* bias_data = (bias == nullptr ? nullptr : bias->data<float>());
     auto* y_data = y->mutable_data<T>(ctx.GetPlace());
-    auto* mean_data = mean->mutable_data<T>(ctx.GetPlace());
-    auto* variance_data = variance->mutable_data<T>(ctx.GetPlace());
+    auto* mean_data = mean->mutable_data<float>(ctx.GetPlace());
+    auto* variance_data = variance->mutable_data<float>(ctx.GetPlace());
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    int r = xpu::layer_norm(dev_ctx.x_context(), x_data, y_data, left, right,
-                            epsilon, scale_data, bias_data, mean_data,
-                            variance_data);
+    int r = xpu::layer_norm(
+        dev_ctx.x_context(), reinterpret_cast<const XPUType*>(x_data),
+        reinterpret_cast<XPUType*>(y_data), left, right, epsilon, scale_data,
+        bias_data, mean_data, variance_data);
     PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
                       platform::errors::External(
                           "XPU layer_norm kernel return wrong value[%d %s]", r,
@@ -57,6 +61,8 @@ class LayerNormXPUKernel : public framework::OpKernel<T> {
 template <typename DeviceContext, typename T>
 class LayerNormGradXPUKernel : public framework::OpKernel<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
@@ -75,19 +81,24 @@ class LayerNormGradXPUKernel : public framework::OpKernel<T> {
     auto* dbias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
     const auto* x_data = x->data<T>();
     const auto* dy_data = dy->data<T>();
-    const auto* mean_data = mean->data<T>();
-    const auto* variance_data = variance->data<T>();
-    const auto* scale_data = (scale == nullptr ? nullptr : scale->data<T>());
+    const auto* mean_data = mean->data<float>();
+    const auto* variance_data = variance->data<float>();
+    const auto* scale_data =
+        (scale == nullptr ? nullptr : scale->data<float>());
     auto* dscale_data =
-        (dscale == nullptr ? nullptr : dscale->mutable_data<T>(ctx.GetPlace()));
-    auto* dbias_data =
-        (dbias == nullptr ? nullptr : dbias->mutable_data<T>(ctx.GetPlace()));
+        (dscale == nullptr ? nullptr
+                           : dscale->mutable_data<float>(ctx.GetPlace()));
+    auto* dbias_data = (dbias == nullptr ? nullptr : dbias->mutable_data<float>(
+                                                         ctx.GetPlace()));
     auto* dx_data =
         (dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()));
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    int r = xpu::layer_norm_grad(dev_ctx.x_context(), x_data, dy_data, dx_data,
-                                 left, right, epsilon, scale_data, mean_data,
-                                 variance_data, dscale_data, dbias_data);
+
+    int r = xpu::layer_norm_grad(
+        dev_ctx.x_context(), reinterpret_cast<const XPUType*>(x_data),
+        reinterpret_cast<const XPUType*>(dy_data),
+        reinterpret_cast<XPUType*>(dx_data), left, right, epsilon, scale_data,
+        mean_data, variance_data, dscale_data, dbias_data);
     PADDLE_ENFORCE_EQ(
         r, XPU_SUCCESS,
         platform::errors::External(
@@ -103,9 +114,13 @@ namespace ops = paddle::operators;
 REGISTER_OP_XPU_KERNEL(
     layer_norm,
-    ops::LayerNormXPUKernel<paddle::platform::XPUDeviceContext, float>);
+    ops::LayerNormXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::LayerNormXPUKernel<paddle::platform::XPUDeviceContext,
+                            paddle::platform::float16>);
 REGISTER_OP_XPU_KERNEL(
     layer_norm_grad,
-    ops::LayerNormGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
+    ops::LayerNormGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::LayerNormGradXPUKernel<paddle::platform::XPUDeviceContext,
+                                paddle::platform::float16>);
 #endif  // PADDLE_WITH_XPU
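
Note: the FP16 support added above hinges on XPUTypeTrait, which maps the framework element type T to the type the device library expects, so one kernel template covers both the FP32 and FP16 registrations. A minimal standalone sketch of the pattern (type names hypothetical, not the real Paddle/XDNN definitions):

#include <cstdint>

struct float16 { uint16_t x; };   // stand-in for the framework half type
struct xpu_half { uint16_t x; };  // stand-in for the device half type

template <typename T>
struct XPUTypeTraitSketch { using Type = T; };  // FP32 maps to itself
template <>
struct XPUTypeTraitSketch<float16> { using Type = xpu_half; };

template <typename T>
void launch_kernel(const T* in, T* out, int n) {
  using XPUType = typename XPUTypeTraitSketch<T>::Type;
  // The two types have identical layout, so reinterpret_cast is safe for
  // passing buffers through to the device API.
  const XPUType* din = reinterpret_cast<const XPUType*>(in);
  XPUType* dout = reinterpret_cast<XPUType*>(out);
  (void)din; (void)dout; (void)n;  // the real xpu::layer_norm call goes here
}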
...
@@ -23,24 +23,33 @@ namespace operators {
 template <typename DeviceContext, typename T>
 class MeanXPUKernel : public framework::OpKernel<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* input = context.Input<Tensor>("X");
     auto* output = context.Output<Tensor>("Out");
     output->mutable_data<T>(context.GetPlace());
     auto& dev_ctx = context.template device_context<DeviceContext>();
-    const float* x_data = input->data<float>();
-    float* y_data = output->data<float>();
-    int r = xpu::mean(dev_ctx.x_context(), x_data, y_data, input->numel());
-    PADDLE_ENFORCE_EQ(
-        r, xpu::Error_t::SUCCESS,
-        platform::errors::External(
-            "XPU kernel error, Mean op execution not succeed, error code=%d",
-            r));
+    const T* x_data = input->data<T>();
+    T* y_data = output->data<T>();
+    std::vector<int> x_shape;
+    x_shape.push_back(1);
+    x_shape.push_back(input->numel());
+    std::vector<int> rdims = {1};
+    int r = xpu::reduce_mean(
+        dev_ctx.x_context(), reinterpret_cast<const XPUType*>(x_data),
+        reinterpret_cast<XPUType*>(y_data), x_shape, rdims);
+    PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
+                      platform::errors::External(
+                          "XPU reduce_mean kernel return wrong value[%d %s]", r,
+                          XPUAPIErrorMsg[r]));
   }
 };

 template <typename DeviceContext, typename T>
 class MeanGradXPUKernel : public framework::OpKernel<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto OG = context.Input<Tensor>(framework::GradVarName("Out"));
@@ -49,14 +58,24 @@ class MeanGradXPUKernel : public framework::OpKernel<T> {
     auto IG = context.Output<Tensor>(framework::GradVarName("X"));
     IG->mutable_data<T>(context.GetPlace());
     auto& dev_ctx = context.template device_context<DeviceContext>();
-    float* dx = IG->data<float>();
-    const float* dy = OG->data<float>();
-    int r = xpu::mean_grad(dev_ctx.x_context(), dx, dy, IG->numel());
-    PADDLE_ENFORCE_EQ(
-        r, xpu::Error_t::SUCCESS,
-        platform::errors::External(
-            "XPU kernel error. Mean_grad execution not succeed, error code=%d",
-            r));
+
+    XPUType* dx = reinterpret_cast<XPUType*>(IG->data<T>());
+
+    const T* dy = OG->data<T>();
+    T dy0_value;
+    xpu_wait(dev_ctx.x_context()->xpu_stream);
+    memory::Copy(platform::CPUPlace(), &dy0_value,
+                 BOOST_GET_CONST(platform::XPUPlace, OG->place()), dy,
+                 sizeof(T));
+    float dy0_fp32 = static_cast<float>(dy0_value);
+    dy0_fp32 = dy0_fp32 / static_cast<float>(IG->numel());
+
+    int r = xpu::constant(dev_ctx.x_context(), dx, IG->numel(),
+                          static_cast<XPUType>(dy0_fp32));
+    PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
+                      platform::errors::External(
+                          "XPU constant kernel return wrong value[%d %s]", r,
+                          XPUAPIErrorMsg[r]));
   }
 };
@@ -65,8 +84,12 @@ class MeanGradXPUKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 REGISTER_OP_XPU_KERNEL(
-    mean, ops::MeanXPUKernel<paddle::platform::XPUDeviceContext, float>);
+    mean, ops::MeanXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::MeanXPUKernel<paddle::platform::XPUDeviceContext,
+                       paddle::platform::float16>);
 REGISTER_OP_XPU_KERNEL(
     mean_grad,
-    ops::MeanGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
+    ops::MeanGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::MeanGradXPUKernel<paddle::platform::XPUDeviceContext,
+                           paddle::platform::float16>);
 #endif
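
Note: the rewritten mean_grad avoids xpu::mean_grad entirely. Since d(mean)/dx_i = 1/numel for every element, the kernel synchronizes the stream (xpu_wait), copies the single upstream gradient dy[0] to the host, divides it by numel, and broadcasts the scalar with xpu::constant. A host-side sketch of the same arithmetic, in plain C++:

#include <vector>

// Fill dx with dy0 / numel -- the entire gradient of mean() is one constant.
void mean_grad_sketch(float dy0, std::vector<float>* dx) {
  const float v = dy0 / static_cast<float>(dx->size());
  for (float& g : *dx) g = v;  // what xpu::constant does on the device
}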
...
@@ -113,27 +113,27 @@ class AdamOpXPUKernel : public framework::OpKernel<T> {
     bool use_global_beta_pow = ctx.Attr<bool>("use_global_beta_pow");
     VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;
-    T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
+    float beta1 = static_cast<float>(ctx.Attr<float>("beta1"));
     if (ctx.HasInput("Beta1Tensor")) {
       auto* beta1_tensor = ctx.Input<framework::Tensor>("Beta1Tensor");
-      beta1 = static_cast<T>(GetAttrFromTensor(beta1_tensor));
+      beta1 = static_cast<float>(GetAttrFromTensor(beta1_tensor));
     }
-    T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
+    float beta2 = static_cast<float>(ctx.Attr<float>("beta2"));
     if (ctx.HasInput("Beta2Tensor")) {
       auto* beta2_tensor = ctx.Input<framework::Tensor>("Beta2Tensor");
-      beta2 = static_cast<T>(GetAttrFromTensor(beta2_tensor));
+      beta2 = static_cast<float>(GetAttrFromTensor(beta2_tensor));
     }
-    T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
+    float epsilon = static_cast<float>(ctx.Attr<float>("epsilon"));
     if (ctx.HasInput("EpsilonTensor")) {
       auto* epsilon_tensor = ctx.Input<framework::Tensor>("EpsilonTensor");
-      epsilon = static_cast<T>(GetAttrFromTensor(epsilon_tensor));
+      epsilon = static_cast<float>(GetAttrFromTensor(epsilon_tensor));
     }

     if (grad_var->IsType<framework::LoDTensor>()) {
       auto& grad = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Grad"), "Input",
                                    "Grad", "Adam");
       auto& dev_ctx = ctx.template device_context<DeviceContext>();
-      const T* beta1_pow_ptr = beta1_pow.template data<T>();
-      const T* beta2_pow_ptr = beta2_pow.template data<T>();
+      const float* beta1_pow_ptr = beta1_pow.template data<float>();
+      const float* beta2_pow_ptr = beta2_pow.template data<float>();
       Tensor xpu_beta1_pow;
       Tensor xpu_beta2_pow;
       if (beta1_pow.place() == platform::CPUPlace() &&
@@ -141,50 +141,49 @@ class AdamOpXPUKernel : public framework::OpKernel<T> {
         TensorCopy(beta1_pow, ctx.GetPlace(), dev_ctx, &xpu_beta1_pow);
         TensorCopy(beta2_pow, ctx.GetPlace(), dev_ctx, &xpu_beta2_pow);
         dev_ctx.Wait();
-        beta1_pow_ptr = xpu_beta1_pow.template data<T>();
-        beta2_pow_ptr = xpu_beta2_pow.template data<T>();
+        beta1_pow_ptr = xpu_beta1_pow.template data<float>();
+        beta2_pow_ptr = xpu_beta2_pow.template data<float>();
       }
-      int r = xpu::adam(
-          dev_ctx.x_context(), grad.template data<T>(), mom1.template data<T>(),
-          mom2.template data<T>(), param.template data<T>(), beta1_pow_ptr,
-          beta2_pow_ptr, beta1, beta2, epsilon, lr.template data<T>(),
-          mom1_out.template mutable_data<T>(ctx.GetPlace()),
-          mom2_out.template mutable_data<T>(ctx.GetPlace()),
-          param_out.template mutable_data<T>(ctx.GetPlace()), param.numel());
+
+      int r = xpu::adam(dev_ctx.x_context(), grad.template data<T>(),
+                        mom1.template data<T>(), mom2.template data<T>(),
+                        param.template data<float>(), beta1_pow_ptr,
+                        beta2_pow_ptr, lr.template data<float>(),
+                        mom1_out.template mutable_data<float>(ctx.GetPlace()),
+                        mom2_out.template mutable_data<float>(ctx.GetPlace()),
+                        param_out.template mutable_data<float>(ctx.GetPlace()),
+                        beta1, beta2, epsilon, param.numel());

       if (!use_global_beta_pow) {
         // update in cpu and then copy to xpu
         if (beta1_pow.place() == platform::CPUPlace() &&
             beta2_pow.place() == platform::CPUPlace()) {
-          const T* beta1_pow_p = beta1_pow.template data<T>();
-          beta1_pow_out->mutable_data<T>(platform::CPUPlace())[0] =
+          const float* beta1_pow_p = beta1_pow.template data<float>();
+          beta1_pow_out->mutable_data<float>(platform::CPUPlace())[0] =
               beta1 * beta1_pow_p[0];
-          const T* beta2_pow_p = beta2_pow.template data<T>();
-          beta2_pow_out->mutable_data<T>(platform::CPUPlace())[0] =
+          const float* beta2_pow_p = beta2_pow.template data<float>();
+          beta2_pow_out->mutable_data<float>(platform::CPUPlace())[0] =
               beta2 * beta2_pow_p[0];
+          xpu_wait(dev_ctx.x_context()->xpu_stream);
         } else {
-          T cpu_beta1_pow_out_data;
-          T cpu_beta2_pow_out_data;
-
-          memory::Copy(platform::CPUPlace(), &cpu_beta1_pow_out_data,
-                       BOOST_GET_CONST(platform::XPUPlace, beta1_pow.place()),
-                       beta1_pow_ptr, sizeof(T));
-
-          cpu_beta1_pow_out_data = cpu_beta1_pow_out_data * beta1;
-          memory::Copy(platform::CPUPlace(), &cpu_beta2_pow_out_data,
-                       BOOST_GET_CONST(platform::XPUPlace, beta2_pow.place()),
-                       beta2_pow_ptr, sizeof(T));
-
-          cpu_beta2_pow_out_data = cpu_beta2_pow_out_data * beta2;
-
-          T* beta1_pow_out_p = beta1_pow_out->mutable_data<T>(ctx.GetPlace());
-          T* beta2_pow_out_p = beta2_pow_out->mutable_data<T>(ctx.GetPlace());
-          memory::Copy(BOOST_GET_CONST(platform::XPUPlace, ctx.GetPlace()),
-                       beta1_pow_out_p, platform::CPUPlace(),
-                       &cpu_beta1_pow_out_data, sizeof(T));
-          memory::Copy(BOOST_GET_CONST(platform::XPUPlace, ctx.GetPlace()),
-                       beta2_pow_out_p, platform::CPUPlace(),
-                       &cpu_beta2_pow_out_data, sizeof(T));
+          float* beta1_pow_out_p =
+              beta1_pow_out->mutable_data<float>(ctx.GetPlace());
+          float* beta2_pow_out_p =
+              beta2_pow_out->mutable_data<float>(ctx.GetPlace());
+          int r =
+              xpu::scale(dev_ctx.x_context(), beta1_pow_ptr, beta1_pow_out_p,
+                         beta1_pow.numel(), false, beta1, 0.0f);
+          PADDLE_ENFORCE_EQ(
+              r, xpu::SUCCESS,
+              platform::errors::External(
+                  "XPU scale kernel return wrong value[%d %s]", r,
+                  XPUAPIErrorMsg[r]));
+          r = xpu::scale(dev_ctx.x_context(), beta2_pow_ptr, beta2_pow_out_p,
+                         beta2_pow.numel(), false, beta2, 0.0f);
+          PADDLE_ENFORCE_EQ(
+              r, xpu::SUCCESS,
+              platform::errors::External(
+                  "XPU scale kernel return wrong value[%d %s]", r,
+                  XPUAPIErrorMsg[r]));
         }
         PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
...
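
Note: the else-branch above is where the commit title's "new implementation to replace memcpy" applies. The old code copied beta1_pow/beta2_pow device-to-host, multiplied on the CPU, and copied the results back; the new code keeps everything on the device and expresses beta_pow_out = beta * beta_pow as an xpu::scale with scale factor beta and bias 0. The equivalent host arithmetic, as a plain C++ sketch:

// beta_pow holds a single element in Adam, but the scale form generalizes.
void update_beta_pow_sketch(const float* beta_pow, float* beta_pow_out,
                            int numel, float beta) {
  for (int i = 0; i < numel; ++i) {
    beta_pow_out[i] = beta * beta_pow[i] + 0.0f;  // scale=beta, bias=0
  }
}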
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "gflags/gflags.h"
#include "paddle/fluid/operators/optimizers/adam_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
#ifdef PADDLE_WITH_XPU
template <typename T>
class AdamwOpXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
platform::errors::InvalidArgument(
"Tensor holds the wrong type,Expected Var(%s)'s "
"type is LoDTensor, "
"but the received is %s",
ctx.InputNames("Param").front(),
framework::ToTypeName(param_var->Type())));
using paddle::framework::LoDTensor;
auto& param = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Param"), "Input",
"Param", "Adam");
// auto& grad = Ref(ctx.Input<LoDTensor>("Grad"), "Must set Grad");
auto* grad_var = ctx.InputVar("Grad");
auto& mom1 = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Moment1"), "Input",
"Moment1", "Adam");
auto& mom2 = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Moment2"), "Input",
"Moment2", "Adam");
auto& lr = GET_DATA_SAFELY(ctx.Input<LoDTensor>("LearningRate"), "Input",
"LearningRate", "Adam");
auto& beta1_pow = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Beta1Pow"), "Input",
"Beta1Pow", "Adam");
auto& beta2_pow = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Beta2Pow"), "Input",
"Beta2Pow", "Adam");
auto& param_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("ParamOut"),
"Output", "ParamOut", "Adam");
auto& mom1_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("Moment1Out"),
"Output", "Moment1Out", "Adam");
auto& mom2_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("Moment2Out"),
"Output", "Moment2Out", "Adam");
auto* beta1_pow_out = ctx.Output<LoDTensor>("Beta1PowOut");
auto* beta2_pow_out = ctx.Output<LoDTensor>("Beta2PowOut");
bool skip_update = false;
if (ctx.HasInput("SkipUpdate")) {
auto* skip_update_tensor = ctx.Input<framework::Tensor>("SkipUpdate");
PADDLE_ENFORCE_EQ(skip_update_tensor->numel(), 1,
platform::errors::InvalidArgument(
"Input(SkipUpdate) size must be 1, but get %d",
skip_update_tensor->numel()));
std::vector<bool> skip_update_vec;
TensorToVector(*skip_update_tensor, ctx.device_context(),
&skip_update_vec);
skip_update = skip_update_vec[0];
}
auto& dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
// skip_update=true, just copy input to output, and TensorCopy will call
// mutable_data
if (skip_update) {
VLOG(4) << "Adam skip update";
framework::TensorCopy(param, ctx.GetPlace(), dev_ctx, &param_out);
framework::TensorCopy(mom1, ctx.GetPlace(), dev_ctx, &mom1_out);
framework::TensorCopy(mom2, ctx.GetPlace(), dev_ctx, &mom2_out);
framework::TensorCopy(beta1_pow, ctx.GetPlace(), dev_ctx, beta1_pow_out);
framework::TensorCopy(beta2_pow, ctx.GetPlace(), dev_ctx, beta2_pow_out);
return;
}
bool with_decay = ctx.Attr<bool>("with_decay");
PADDLE_ENFORCE_EQ(beta1_pow_out->numel(), 1,
platform::errors::InvalidArgument(
"Tensor holds the wrong size, Expected beta1 pow "
"output size is 1, but received "
"value is:%d.",
beta1_pow_out->numel()));
PADDLE_ENFORCE_EQ(beta2_pow_out->numel(), 1,
platform::errors::InvalidArgument(
"Tensor holds the wrong size, Expected beta2 pow "
"output size is 1, but received "
"value is:%d.",
beta2_pow_out->numel()));
bool use_global_beta_pow = ctx.Attr<bool>("use_global_beta_pow");
VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;
float beta1 = static_cast<float>(ctx.Attr<float>("beta1"));
if (ctx.HasInput("Beta1Tensor")) {
auto* beta1_tensor = ctx.Input<framework::Tensor>("Beta1Tensor");
beta1 = static_cast<float>(GetAttrFromTensor(beta1_tensor));
}
float beta2 = static_cast<float>(ctx.Attr<float>("beta2"));
if (ctx.HasInput("Beta2Tensor")) {
auto* beta2_tensor = ctx.Input<framework::Tensor>("Beta2Tensor");
beta2 = static_cast<float>(GetAttrFromTensor(beta2_tensor));
}
float epsilon = static_cast<float>(ctx.Attr<float>("epsilon"));
if (ctx.HasInput("EpsilonTensor")) {
auto* epsilon_tensor = ctx.Input<framework::Tensor>("EpsilonTensor");
epsilon = static_cast<float>(GetAttrFromTensor(epsilon_tensor));
}
if (grad_var->IsType<framework::LoDTensor>()) {
auto& grad = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Grad"), "Input",
"Grad", "Adam");
const float* beta1_pow_ptr = beta1_pow.template data<float>();
const float* beta2_pow_ptr = beta2_pow.template data<float>();
Tensor xpu_beta1_pow;
Tensor xpu_beta2_pow;
if (beta1_pow.place() == platform::CPUPlace() &&
beta2_pow.place() == platform::CPUPlace()) {
TensorCopy(beta1_pow, ctx.GetPlace(), dev_ctx, &xpu_beta1_pow);
TensorCopy(beta2_pow, ctx.GetPlace(), dev_ctx, &xpu_beta2_pow);
dev_ctx.Wait();
beta1_pow_ptr = xpu_beta1_pow.template data<float>();
beta2_pow_ptr = xpu_beta2_pow.template data<float>();
}
if (with_decay) {
float coeff = ctx.Attr<float>("coeff");
int r =
xpu::adamw(dev_ctx.x_context(), grad.template data<T>(),
mom1.template data<float>(), mom2.template data<float>(),
param.template data<T>(), beta1_pow_ptr, beta2_pow_ptr,
lr.template data<float>(),
mom1_out.template mutable_data<float>(ctx.GetPlace()),
mom2_out.template mutable_data<float>(ctx.GetPlace()),
param_out.template mutable_data<T>(ctx.GetPlace()),
beta1, beta2, epsilon, coeff, param.numel());
PADDLE_ENFORCE_EQ(
r, xpu::SUCCESS,
platform::errors::External(
"XPU kernel adamw occur error in adamw error code ", r,
XPUAPIErrorMsg[r]));
} else {
int r =
xpu::adam(dev_ctx.x_context(), grad.template data<T>(),
mom1.template data<float>(), mom2.template data<float>(),
param.template data<T>(), beta1_pow_ptr, beta2_pow_ptr,
lr.template data<float>(),
mom1_out.template mutable_data<float>(ctx.GetPlace()),
mom2_out.template mutable_data<float>(ctx.GetPlace()),
param_out.template mutable_data<T>(ctx.GetPlace()), beta1,
beta2, epsilon, param.numel());
PADDLE_ENFORCE_EQ(
r, xpu::SUCCESS,
platform::errors::External(
"XPU kernel adam occur error in adamw error code ", r,
XPUAPIErrorMsg[r]));
}
if (!use_global_beta_pow) {
// update in cpu and then copy to xpu
if (beta1_pow.place() == platform::CPUPlace() &&
beta2_pow.place() == platform::CPUPlace()) {
const float* beta1_pow_p = beta1_pow.template data<float>();
beta1_pow_out->mutable_data<float>(platform::CPUPlace())[0] =
beta1 * beta1_pow_p[0];
const float* beta2_pow_p = beta2_pow.template data<float>();
beta2_pow_out->mutable_data<float>(platform::CPUPlace())[0] =
beta2 * beta2_pow_p[0];
xpu_wait(dev_ctx.x_context()->xpu_stream);
} else {
float* beta1_pow_out_p =
beta1_pow_out->mutable_data<float>(ctx.GetPlace());
float* beta2_pow_out_p =
beta2_pow_out->mutable_data<float>(ctx.GetPlace());
int r =
xpu::scale(dev_ctx.x_context(), beta1_pow_ptr, beta1_pow_out_p,
beta1_pow.numel(), false, beta1, 0.0f);
PADDLE_ENFORCE_EQ(
r, xpu::SUCCESS,
platform::errors::External(
"XPU kernel scale occur error in adamw error code ", r,
XPUAPIErrorMsg[r]));
r = xpu::scale(dev_ctx.x_context(), beta2_pow_ptr, beta2_pow_out_p,
beta2_pow.numel(), false, beta2, 0.0f);
PADDLE_ENFORCE_EQ(
r, xpu::SUCCESS,
platform::errors::External(
"XPU kernel scale occur error in adamw error code ", r,
XPUAPIErrorMsg[r]));
}
}
} else {
PADDLE_ENFORCE_EQ(1, 2, platform::errors::InvalidArgument(
"Variable type not supported by adamw_op"));
}
}
};
#endif
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
#ifdef PADDLE_WITH_XPU
REGISTER_OP_XPU_KERNEL(adamw, ops::AdamwOpXPUKernel<float>);
#endif
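
Note: with_decay selects between xpu::adamw and plain xpu::adam above; coeff is the decoupled weight-decay coefficient applied directly to the parameter rather than folded into the gradient. As a reference, one common per-element formulation of an AdamW step (a sketch only; the exact fused formula inside xpu::adamw is not shown in this diff):

#include <cmath>

void adamw_step_sketch(float g, float lr, float beta1, float beta2,
                       float epsilon, float coeff, float beta1_pow,
                       float beta2_pow, float* m, float* v, float* p) {
  *m = beta1 * (*m) + (1.0f - beta1) * g;      // first moment
  *v = beta2 * (*v) + (1.0f - beta2) * g * g;  // second moment
  const float m_hat = *m / (1.0f - beta1_pow); // bias correction
  const float v_hat = *v / (1.0f - beta2_pow);
  // Decoupled decay: coeff * p joins the update, not the gradient.
  *p -= lr * (m_hat / (std::sqrt(v_hat) + epsilon) + coeff * (*p));
}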
...
@@ -54,9 +54,11 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
     int len = logits->numel();
     T* clip_logits_data =
         clip_logits.mutable_data<T>(context.GetPlace(), len * sizeof(T));
     r = xpu::clip_v2(dev_ctx.x_context(), logits->data<float>(),
                      clip_logits_data, len, static_cast<float>(-1e20),
                      static_cast<float>(1e20));
     PADDLE_ENFORCE_EQ(
         r, xpu::Error_t::SUCCESS,
         platform::errors::External("XPU kernel error. clip "
@@ -108,10 +110,88 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
     }
   }
 };

template <typename T>
class SoftmaxWithCrossEntropyGradXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* out_grad =
context.Input<Tensor>(framework::GradVarName("Loss"));
const Tensor* labels = context.Input<Tensor>("Label");
Tensor* logit_grad =
context.Output<Tensor>(framework::GradVarName("Logits"));
logit_grad->mutable_data<T>(context.GetPlace());
const Tensor* softmax = context.Input<Tensor>("Softmax");
const bool use_softmax = context.Attr<bool>("use_softmax");
const bool soft_label = context.Attr<bool>("soft_label");
auto ignore_index = context.Attr<int>("ignore_index");
const int rank = logit_grad->dims().size();
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
PADDLE_ENFORCE_EQ(axis, rank - 1, platform::errors::InvalidArgument(
"axis should == rank - 1"));
const int n = SizeToAxis(axis, logit_grad->dims());
const int d = SizeFromAxis(axis, logit_grad->dims());
auto& dev_ctx =
context.template device_context<platform::XPUDeviceContext>();
int r = XPU_SUCCESS;
if (soft_label) {
r = xpu::soft_softmax_with_cross_entropy_grad<XPUType>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(out_grad->data<T>()),
reinterpret_cast<const XPUType*>(labels->data<T>()),
reinterpret_cast<const XPUType*>(softmax->data<T>()),
reinterpret_cast<XPUType*>(logit_grad->data<T>()), use_softmax, n, d);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External(
"XPU API(soft_softmax_with_cross_entropy_grad) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
} else {
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
int* labels_int_ptr_l3 =
RAII_GUARD.alloc_l3_or_gm<int32_t>(labels->numel());
r = xpu::cast_v2<int64_t, int32_t>(dev_ctx.x_context(),
labels->data<int64_t>(),
labels_int_ptr_l3, labels->numel());
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
"XPU API(cast_v2) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
r = xpu::hard_softmax_with_cross_entropy_grad<XPUType, int>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(out_grad->data<T>()),
labels_int_ptr_l3,
reinterpret_cast<const XPUType*>(softmax->data<T>()),
reinterpret_cast<XPUType*>(logit_grad->data<T>()), ignore_index,
use_softmax, n, d);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External(
"XPU API(hard_softmax_with_cross_entropy_grad) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
}
}
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(softmax_with_cross_entropy,
                       ops::SoftmaxWithCrossEntropyXPUKernel<float>);
REGISTER_OP_XPU_KERNEL(
softmax_with_cross_entropy_grad,
ops::SoftmaxWithCrossEntropyGradXPUKernel<float>,
ops::SoftmaxWithCrossEntropyGradXPUKernel<paddle::platform::float16>);
#endif
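
Note: in the hard-label branch of the new grad kernel above, the int64 labels are first cast to int32 in L3/global scratch memory (cast_v2 under a ctx_guard) because the grad primitive takes int. The backward math itself, with use_softmax on, is d(logits) = out_grad * (softmax - one_hot(label)), zeroed for rows whose label equals ignore_index. A host-side sketch in plain C++:

#include <vector>

void hard_xent_grad_sketch(const std::vector<float>& softmax,   // n x d
                           const std::vector<int>& labels,      // n
                           const std::vector<float>& out_grad,  // n
                           int n, int d, int ignore_index,
                           std::vector<float>* logit_grad) {    // n x d
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j < d; ++j) {
      const float onehot = (labels[i] == j) ? 1.0f : 0.0f;
      (*logit_grad)[i * d + j] =
          (labels[i] == ignore_index)
              ? 0.0f
              : out_grad[i] * (softmax[i * d + j] - onehot);
    }
  }
}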
...
@@ -21,6 +21,8 @@ using framework::Tensor;
 template <typename DeviceContext, typename T>
 class SumXPUKernel : public framework::OpKernel<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
  public:
   void Compute(const framework::ExecutionContext &context) const override {
     auto in_vars = context.MultiInputVar("X");
@@ -35,8 +37,7 @@ class SumXPUKernel : public framework::OpKernel<T> {
       out->mutable_data<T>(context.GetPlace());
     }
     auto &dev_ctx = context.template device_context<DeviceContext>();
-    std::vector<const float *> ptrs(N, nullptr);
-    int valid_count = 0;
+    std::vector<const XPUType *> ptrs;
     for (int i = 0; i < N; ++i) {
       PADDLE_ENFORCE_EQ(
           in_vars[i]->IsType<framework::LoDTensor>(), true,
@@ -45,30 +46,14 @@ class SumXPUKernel : public framework::OpKernel<T> {
       if (in_t.numel() == 0) {
         continue;
       }
-      ptrs[valid_count] = reinterpret_cast<const float *>(in_t.data<T>());
-      valid_count++;
+      ptrs.push_back(reinterpret_cast<const XPUType *>(in_t.data<T>()));
     }
-    int r = xpu::sum_batch(dev_ctx.x_context(), ptrs.data(), out->data<T>(),
-                           valid_count, out->numel());
-    if (r == xpu::Error_t::INVALID_PARAM) {
-      PADDLE_ENFORCE_EQ(
-          r, xpu::Error_t::SUCCESS,
-          platform::errors::InvalidArgument(
-              "XPU kernel error of SumOp, error message: INVALID_PARAM, "
-              "please check your input & output."));
-    } else if (r == xpu::Error_t::RUNTIME_ERROR) {
-      PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
-                        platform::errors::Unavailable(
-                            "XPU kernel error of SumOp, error message: "
-                            "RUNTIME_ERROR, please check whether Baidu "
-                            "Kunlun Card is properly installed."));
-    } else if (r == xpu::Error_t::NO_ENOUGH_WORKSPACE) {
-      PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
-                        platform::errors::ResourceExhausted(
-                            "XPU kernel error of SumOp, error "
-                            "message: NO_ENOUGH_WORKSPACE, XPU "
-                            "has no enough memory."));
-    }
+    int r = xpu::sum(dev_ctx.x_context(), ptrs,
+                     reinterpret_cast<XPUType *>(out->data<T>()), out->numel());
+    PADDLE_ENFORCE_EQ(
+        r, XPU_SUCCESS,
+        platform::errors::External("XPU sum kernel return wrong value[%d %s]",
+                                   r, XPUAPIErrorMsg[r]));
   }
 };
@@ -78,5 +63,7 @@ class SumXPUKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 REGISTER_OP_XPU_KERNEL(
-    sum, ops::SumXPUKernel<paddle::platform::XPUDeviceContext, float>);
+    sum, ops::SumXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::SumXPUKernel<paddle::platform::XPUDeviceContext,
+                      paddle::platform::float16>);
 #endif
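
Note: the rewrite above drops the fixed-size pointer array and valid_count bookkeeping in favor of push_back, so empty inputs are simply skipped and xpu::sum receives exactly the live buffers. Element-wise, the operation is a fused multi-input add; a host-side sketch in plain C++:

#include <vector>

void sum_sketch(const std::vector<const float*>& ptrs, float* out, int numel) {
  for (int i = 0; i < numel; ++i) {
    float acc = 0.0f;
    for (const float* p : ptrs) acc += p[i];  // one term per input buffer
    out[i] = acc;  // xpu::sum fuses this across inputs on the device
  }
}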
...
@@ -26,6 +26,8 @@ using framework::Tensor;
 template <typename DeviceContext, typename T>
 class TransposeXPUKernel : public framework::OpKernel<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto x = context.Input<framework::Tensor>("X");
@@ -46,8 +48,9 @@ class TransposeXPUKernel : public framework::OpKernel<T> {
       x_shape_host[i] = x_dims[i];
     }
     auto& dev_ctx = context.template device_context<DeviceContext>();
-    int r = xpu::transpose<T>(dev_ctx.x_context(), x_data, y_data, x_shape_host,
-                              axis);
+    int r = xpu::transpose<XPUType>(
+        dev_ctx.x_context(), reinterpret_cast<const XPUType*>(x_data),
+        reinterpret_cast<XPUType*>(y_data), x_shape_host, axis);
     PADDLE_ENFORCE_EQ(
         r, xpu::Error_t::SUCCESS,
         platform::errors::External("XPU kernel error! error code=%d", r));
@@ -56,6 +59,8 @@ class TransposeXPUKernel : public framework::OpKernel<T> {
 template <typename DeviceContext, typename T>
 class TransposeGradXPUKernel : public framework::OpKernel<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* out_grad =
@@ -77,8 +82,11 @@ class TransposeGradXPUKernel : public framework::OpKernel<T> {
       out_shape_host[i] = out_grad->dims()[i];
     }
     auto& dev_ctx = context.template device_context<DeviceContext>();
-    int r = xpu::transpose<T>(dev_ctx.x_context(), out_grad->data<T>(),
-                              x_grad->data<T>(), out_shape_host, reversed_axis);
+    int r = xpu::transpose<XPUType>(
+        dev_ctx.x_context(),
+        reinterpret_cast<const XPUType*>(out_grad->data<T>()),
+        reinterpret_cast<XPUType*>(x_grad->data<T>()), out_shape_host,
+        reversed_axis);
     PADDLE_ENFORCE_EQ(
         r, xpu::Error_t::SUCCESS,
         platform::errors::External("XPU kernel error! error code=%d", r));
@@ -92,15 +100,23 @@ namespace ops = paddle::operators;
 REGISTER_OP_XPU_KERNEL(
     transpose,
-    ops::TransposeXPUKernel<paddle::platform::XPUDeviceContext, float>);
+    ops::TransposeXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::TransposeXPUKernel<paddle::platform::XPUDeviceContext,
+                            paddle::platform::float16>);
 REGISTER_OP_XPU_KERNEL(
     transpose_grad,
-    ops::TransposeGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
+    ops::TransposeGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::TransposeGradXPUKernel<paddle::platform::XPUDeviceContext,
+                                paddle::platform::float16>);
 REGISTER_OP_XPU_KERNEL(
     transpose2,
-    ops::TransposeXPUKernel<paddle::platform::XPUDeviceContext, float>);
+    ops::TransposeXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::TransposeXPUKernel<paddle::platform::XPUDeviceContext,
+                            paddle::platform::float16>);
 REGISTER_OP_XPU_KERNEL(
     transpose2_grad,
-    ops::TransposeGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
+    ops::TransposeGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::TransposeGradXPUKernel<paddle::platform::XPUDeviceContext,
+                                paddle::platform::float16>);
 #endif  // PADDLE_WITH_XPU
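
Note: both kernels pass the input shape and an axis permutation; the grad kernel reuses the same primitive with reversed_axis, the inverse permutation. A host-side reference transpose over a flat buffer (plain C++ sketch):

#include <cstddef>
#include <vector>

void transpose_sketch(const std::vector<float>& in, std::vector<float>* out,
                      const std::vector<int>& shape,
                      const std::vector<int>& axis) {
  const int ndim = static_cast<int>(shape.size());
  std::vector<int> out_shape(ndim), in_stride(ndim, 1), out_stride(ndim, 1);
  for (int i = 0; i < ndim; ++i) out_shape[i] = shape[axis[i]];
  for (int i = ndim - 2; i >= 0; --i) {
    in_stride[i] = in_stride[i + 1] * shape[i + 1];
    out_stride[i] = out_stride[i + 1] * out_shape[i + 1];
  }
  std::vector<int> coord(ndim);  // input coordinates of the current element
  for (size_t lin = 0; lin < in.size(); ++lin) {
    size_t rem = lin, dst = 0;
    for (int i = 0; i < ndim; ++i) {
      coord[i] = static_cast<int>(rem / in_stride[i]);
      rem %= in_stride[i];
    }
    // Output dimension j takes input dimension axis[j].
    for (int j = 0; j < ndim; ++j) dst += coord[axis[j]] * out_stride[j];
    (*out)[dst] = in[lin];
  }
}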
...
@@ -79,6 +79,35 @@ XPUOpMap& get_kl2_ops() {
      {"batch_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
      {"batch_norm_grad",
       XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"layer_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"layer_norm_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"mean", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"mean_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"adam", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"adamw", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"softmax_with_cross_entropy",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"softmax_with_cross_entropy_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"transpose", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"transpose_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"transpose2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"transpose2_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
      // AddMore
  };
...
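
Note: the entries above are what gate kernel dispatch: each op name maps to the set of (dtype, place) kernel types the KL2 backend registers, so an op only runs on XPU when its dtype appears in the set. A reduced sketch of that lookup, in plain C++ (names hypothetical):

#include <map>
#include <set>
#include <string>

enum class DType { FP32, FP16 };

using XPUKernelSetSketch = std::set<DType>;
using XPUOpMapSketch = std::map<std::string, XPUKernelSetSketch>;

bool xpu_supports(const XPUOpMapSketch& ops, const std::string& op, DType dt) {
  auto it = ops.find(op);
  return it != ops.end() && it->second.count(dt) > 0;  // op and dtype both registered
}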