未验证 提交 ff96a7d5 编写于 作者: T taixiurong 提交者: GitHub

update elementwise api in kunlun (#35021)

上级 881e55e4
...@@ -35,7 +35,7 @@ ELSE () ...@@ -35,7 +35,7 @@ ELSE ()
ENDIF() ENDIF()
SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev") SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210804") SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210818")
SET(XPU_XRE_URL "${XPU_BASE_URL}/${XPU_XRE_DIR_NAME}.tar.gz" CACHE STRING "" FORCE) SET(XPU_XRE_URL "${XPU_BASE_URL}/${XPU_XRE_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
SET(XPU_XDNN_URL "${XPU_BASE_URL}/${XPU_XDNN_DIR_NAME}.tar.gz" CACHE STRING "" FORCE) SET(XPU_XDNN_URL "${XPU_BASE_URL}/${XPU_XDNN_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
SET(XPU_XCCL_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210623/${XPU_XCCL_DIR_NAME}.tar.gz" CACHE STRING "" FORCE) SET(XPU_XCCL_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210623/${XPU_XCCL_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
......
...@@ -23,93 +23,45 @@ limitations under the License. */ ...@@ -23,93 +23,45 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T>
class ElementwiseAddXPUKernel : public framework::OpKernel<T> { class ElementwiseAddXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
// XPUElementwise<T>(ctx, xpu::add<T>); XPUElementwise<T, XPUType>(ctx, xpu::broadcast_add<XPUType>);
// ToDo(QingshuChen): update this optimization to elementwise_xpu.h
auto x_var = ctx.InputVar("X");
PADDLE_ENFORCE_NE(x_var, nullptr, platform::errors::InvalidArgument(
"Cannot get input Variable X"));
PADDLE_ENFORCE_EQ(
x_var->IsType<framework::LoDTensor>(), true,
platform::errors::InvalidArgument(
"XPU only support LoDTensor, Input(X) is not LoDTensor"));
auto x = x_var->Get<framework::LoDTensor>();
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto* z = ctx.Output<framework::LoDTensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
auto x_dims = x.dims();
auto y_dims = y->dims();
int max_dim = std::max(x_dims.size(), y_dims.size());
int axis = ctx.Attr<int>("axis");
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
PADDLE_ENFORCE_GE(
axis, 0,
platform::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LT(
axis, max_dim,
platform::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.", max_dim,
axis));
std::vector<int> x_dims_vec(max_dim, 1);
std::vector<int> y_dims_vec(max_dim, 1);
if (x_dims.size() == max_dim) {
for (int i = 0; i < max_dim; i++) {
x_dims_vec[i] = x_dims[i];
}
} else {
for (int i = 0; i < x_dims.size(); i++) {
x_dims_vec[i + axis] = x_dims[i];
}
}
if (y_dims.size() == max_dim) {
for (int i = 0; i < max_dim; i++) {
y_dims_vec[i] = y_dims[i];
} }
} else { };
for (int i = 0; i < y_dims.size(); i++) {
y_dims_vec[i + axis] = y_dims[i];
}
}
const T* x_data = x.data<T>();
const T* y_data = y->data<T>();
T* z_data = z->data<T>();
auto& dev_ctx = static std::vector<int> get_rdims(const std::vector<int>& xdims,
ctx.template device_context<paddle::platform::XPUDeviceContext>(); const std::vector<int>& ydims) {
int ret = xpu::SUCCESS; std::vector<int> rdims;
ret = xpu::broadcast_add<T>(dev_ctx.x_context(), x_data, y_data, z_data, for (size_t i = 0; i < xdims.size(); i++) {
x_dims_vec, y_dims_vec); if (xdims[i] != ydims[i]) {
PADDLE_ENFORCE_EQ( rdims.push_back(i);
ret, xpu::SUCCESS,
platform::errors::External(
"XPU kernel Elementwise occur error in XPUElementwise error code ",
ret, XPUAPIErrorMsg[ret]));
} }
}; }
return rdims;
}
template <typename DeviceContext, typename T> template <typename T>
class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> { class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx); ElemwiseGradKernel<T>::Compute(ctx);
// XPUElementwiseGrad<T>(ctx, xpu::add_grad<T>, false);
auto* x = ctx.Input<framework::Tensor>("X"); auto* x = ctx.Input<framework::Tensor>("X");
auto* y = ctx.Input<framework::Tensor>("Y"); auto* y = ctx.Input<framework::Tensor>("Y");
auto* dz = ctx.Input<framework::Tensor>(framework::GradVarName("Out")); auto* dz = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X")); auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y")); auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
const framework::DDim& x_dims = x->dims(); const framework::DDim& x_dims = x->dims();
const framework::DDim& y_dims = y->dims(); const framework::DDim& y_dims = y->dims();
int max_dim = std::max(x_dims.size(), y_dims.size()); const framework::DDim& dz_dims = dz->dims();
int axis = ctx.Attr<int>("axis");
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis); axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
int max_dim = std::max(x_dims.size(), y_dims.size());
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
axis, 0, axis, 0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -120,66 +72,74 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> { ...@@ -120,66 +72,74 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.", max_dim, "Axis should be less than %d, but received axis is %d.", max_dim,
axis)); axis));
std::vector<int> x_dims_vec(max_dim, 1); std::vector<int> x_dims_vec(max_dim, 1);
std::vector<int> y_dims_vec(max_dim, 1); std::vector<int> y_dims_vec(max_dim, 1);
int x_len = 1; std::vector<int> z_dims_vec(max_dim, 1);
int y_len = 1;
if (x_dims.size() == max_dim) { if (x_dims.size() == max_dim) {
for (int i = 0; i < max_dim; i++) { for (int i = 0; i < max_dim; i++) {
x_dims_vec[i] = x_dims[i]; x_dims_vec[i] = x_dims[i];
x_len *= x_dims_vec[i];
} }
} else { } else {
for (int i = 0; i < x_dims.size(); i++) { for (int i = 0; i < x_dims.size(); i++) {
x_dims_vec[i + axis] = x_dims[i]; x_dims_vec[i + axis] = x_dims[i];
x_len *= x_dims_vec[i];
} }
} }
if (y_dims.size() == max_dim) { if (y_dims.size() == max_dim) {
for (int i = 0; i < max_dim; i++) { for (int i = 0; i < max_dim; i++) {
y_dims_vec[i] = y_dims[i]; y_dims_vec[i] = y_dims[i];
y_len *= y_dims_vec[i];
} }
} else { } else {
for (int i = 0; i < y_dims.size(); i++) { for (int i = 0; i < y_dims.size(); i++) {
y_dims_vec[i + axis] = y_dims[i]; y_dims_vec[i + axis] = y_dims[i];
y_len *= y_dims_vec[i];
} }
} }
for (int i = 0; i < max_dim; i++) {
z_dims_vec[i] = dz_dims[i];
}
std::vector<int> rdims_for_x;
std::vector<int> rdims_for_y;
rdims_for_x = get_rdims(x_dims_vec, z_dims_vec);
rdims_for_y = get_rdims(y_dims_vec, z_dims_vec);
const T* dz_data = dz->data<T>(); const T* dz_data = dz->data<T>();
framework::Tensor dx_local_tensor; auto& dev_ctx =
framework::Tensor dy_local_tensor; ctx.template device_context<paddle::platform::XPUDeviceContext>();
bool need_wait = false; if (dx != nullptr) {
T* dx_data = nullptr; if (rdims_for_x.size() == 0) {
T* dy_data = nullptr; framework::TensorCopy(
if (dx) { *dz, ctx.GetPlace(),
dx_data = dx->mutable_data<T>(ctx.GetPlace()); ctx.template device_context<platform::DeviceContext>(), dx);
} else { } else {
dx_data = T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
dx_local_tensor.mutable_data<T>(ctx.GetPlace(), x_len * sizeof(T)); int ret = xpu::reduce_sum<XPUType>(
need_wait = true; dev_ctx.x_context(), reinterpret_cast<const XPUType*>(dz_data),
reinterpret_cast<XPUType*>(dx_data), z_dims_vec, rdims_for_x);
PADDLE_ENFORCE_EQ(
ret, xpu::SUCCESS,
platform::errors::External("XPU kernel reduce_sum occur error in "
"XPUElementwise error code ",
ret, XPUAPIErrorMsg[ret]));
} }
if (dy) {
dy_data = dy->mutable_data<T>(ctx.GetPlace());
} else {
dy_data =
dy_local_tensor.mutable_data<T>(ctx.GetPlace(), y_len * sizeof(T));
need_wait = true;
} }
auto& dev_ctx = if (dy != nullptr) {
ctx.template device_context<paddle::platform::XPUDeviceContext>(); if (rdims_for_y.size() == 0) {
int ret = xpu::broadcast_add_grad<T>(dev_ctx.x_context(), dz_data, dz_data, framework::TensorCopy(
dz_data, dz_data, dy_data, dx_data, *dz, ctx.GetPlace(),
x_dims_vec, y_dims_vec); ctx.template device_context<platform::DeviceContext>(), dy);
} else {
T* dy_data = dy->mutable_data<T>(ctx.GetPlace());
int ret = xpu::reduce_sum<XPUType>(
dev_ctx.x_context(), reinterpret_cast<const XPUType*>(dz_data),
reinterpret_cast<XPUType*>(dy_data), z_dims_vec, rdims_for_y);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ret, xpu::SUCCESS, ret, xpu::SUCCESS,
platform::errors::External( platform::errors::External("XPU kernel reduce_sum occur error in "
"XPU kernel Elementwise occur error in XPUElementwise error code ", "XPUElementwise error code ",
ret, XPUAPIErrorMsg[ret])); ret, XPUAPIErrorMsg[ret]));
if (need_wait && dev_ctx.x_context()->xpu_stream) { }
dev_ctx.Wait();
} }
} }
}; };
...@@ -189,10 +149,9 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> { ...@@ -189,10 +149,9 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(elementwise_add, ops::ElementwiseAddXPUKernel<float>,
ops::ElementwiseAddXPUKernel<paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL( REGISTER_OP_XPU_KERNEL(
elementwise_add, elementwise_add_grad, ops::ElementwiseAddGradXPUKernel<float>,
ops::ElementwiseAddXPUKernel<paddle::platform::XPUDeviceContext, float>); ops::ElementwiseAddGradXPUKernel<paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(elementwise_add_grad,
ops::ElementwiseAddGradXPUKernel<
paddle::platform::XPUDeviceContext, float>);
#endif #endif
...@@ -19,30 +19,33 @@ limitations under the License. */ ...@@ -19,30 +19,33 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T>
class ElementwiseDivXPUKernel : public framework::OpKernel<T> { class ElementwiseDivXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
XPUElementwise<T>(ctx, xpu::div<T>); XPUElementwise<T, XPUType>(ctx, xpu::broadcast_div<XPUType>);
} }
}; };
template <typename DeviceContext, typename T> template <typename T>
class ElementwiseDivGradXPUKernel : public ElemwiseGradKernel<T> { class ElementwiseDivGradXPUKernel : public ElemwiseGradKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx); ElemwiseGradKernel<T>::Compute(ctx);
XPUElementwiseGrad<T>(ctx, xpu::div_grad<T>, true); XPUElementwiseGrad<T, XPUType>(ctx, xpu::broadcast_div_grad<XPUType>, true);
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(elementwise_div, ops::ElementwiseDivXPUKernel<float>,
ops::ElementwiseDivXPUKernel<paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL( REGISTER_OP_XPU_KERNEL(
elementwise_div, elementwise_div_grad, ops::ElementwiseDivGradXPUKernel<float>,
ops::ElementwiseDivXPUKernel<paddle::platform::XPUDeviceContext, float>); ops::ElementwiseDivGradXPUKernel<paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(elementwise_div_grad,
ops::ElementwiseDivGradXPUKernel<
paddle::platform::XPUDeviceContext, float>);
#endif #endif
...@@ -21,17 +21,22 @@ namespace operators { ...@@ -21,17 +21,22 @@ namespace operators {
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class ElementwiseFloordivXPUKernel : public framework::OpKernel<T> { class ElementwiseFloordivXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
XPUElementwise<T>(ctx, xpu::floordiv<T>); XPUElementwise<T, XPUType>(ctx, xpu::broadcast_floordiv<XPUType>);
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(elementwise_floordiv, REGISTER_OP_XPU_KERNEL(
ops::ElementwiseFloordivXPUKernel< elementwise_floordiv,
paddle::platform::XPUDeviceContext, float>); ops::ElementwiseFloordivXPUKernel<paddle::platform::XPUDeviceContext,
float>,
ops::ElementwiseFloordivXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
#endif #endif
...@@ -20,20 +20,24 @@ limitations under the License. */ ...@@ -20,20 +20,24 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T>
class ElementwiseMaxXPUKernel : public framework::OpKernel<T> { class ElementwiseMaxXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
XPUElementwise<T>(ctx, xpu::max<T>); XPUElementwise<T, XPUType>(ctx, xpu::broadcast_max<XPUType>);
} }
}; };
template <typename DeviceContext, typename T> template <typename T>
class ElementwiseMaxGradXPUKernel : public ElemwiseGradKernel<T> { class ElementwiseMaxGradXPUKernel : public ElemwiseGradKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx); ElemwiseGradKernel<T>::Compute(ctx);
XPUElementwiseGrad<T>(ctx, xpu::max_grad<T>, true); XPUElementwiseGrad<T, XPUType>(ctx, xpu::broadcast_max_grad<XPUType>, true);
} }
}; };
...@@ -41,10 +45,9 @@ class ElementwiseMaxGradXPUKernel : public ElemwiseGradKernel<T> { ...@@ -41,10 +45,9 @@ class ElementwiseMaxGradXPUKernel : public ElemwiseGradKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(elementwise_max, ops::ElementwiseMaxXPUKernel<float>,
ops::ElementwiseMaxXPUKernel<paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL( REGISTER_OP_XPU_KERNEL(
elementwise_max, elementwise_max_grad, ops::ElementwiseMaxGradXPUKernel<float>,
ops::ElementwiseMaxXPUKernel<paddle::platform::XPUDeviceContext, float>); ops::ElementwiseMaxGradXPUKernel<paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(elementwise_max_grad,
ops::ElementwiseMaxGradXPUKernel<
paddle::platform::XPUDeviceContext, float>);
#endif #endif
...@@ -20,20 +20,24 @@ limitations under the License. */ ...@@ -20,20 +20,24 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T>
class ElementwiseMinXPUKernel : public framework::OpKernel<T> { class ElementwiseMinXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
XPUElementwise<T>(ctx, xpu::min<T>); XPUElementwise<T, XPUType>(ctx, xpu::broadcast_min<XPUType>);
} }
}; };
template <typename DeviceContext, typename T> template <typename T>
class ElementwiseMinGradXPUKernel : public ElemwiseGradKernel<T> { class ElementwiseMinGradXPUKernel : public ElemwiseGradKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx); ElemwiseGradKernel<T>::Compute(ctx);
XPUElementwiseGrad<T>(ctx, xpu::min_grad<T>, true); XPUElementwiseGrad<T, XPUType>(ctx, xpu::broadcast_min_grad<XPUType>, true);
} }
}; };
...@@ -41,10 +45,9 @@ class ElementwiseMinGradXPUKernel : public ElemwiseGradKernel<T> { ...@@ -41,10 +45,9 @@ class ElementwiseMinGradXPUKernel : public ElemwiseGradKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(elementwise_min, ops::ElementwiseMinXPUKernel<float>,
ops::ElementwiseMinXPUKernel<paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL( REGISTER_OP_XPU_KERNEL(
elementwise_min, elementwise_min_grad, ops::ElementwiseMinGradXPUKernel<float>,
ops::ElementwiseMinXPUKernel<paddle::platform::XPUDeviceContext, float>); ops::ElementwiseMinGradXPUKernel<paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(elementwise_min_grad,
ops::ElementwiseMinGradXPUKernel<
paddle::platform::XPUDeviceContext, float>);
#endif #endif
...@@ -18,20 +18,25 @@ limitations under the License. */ ...@@ -18,20 +18,25 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_xpu.h" #include "paddle/fluid/operators/elementwise/elementwise_xpu.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T>
template <typename T>
class ElementwiseMulXPUKernel : public framework::OpKernel<T> { class ElementwiseMulXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
XPUElementwise<T>(ctx, xpu::mul<T>); XPUElementwise<T, XPUType>(ctx, xpu::broadcast_mul<XPUType>);
} }
}; };
// DEFINE_XPU_GRAD_KERNEL(Mul, mul, true);
template <typename DeviceContext, typename T> template <typename T>
class ElementwiseMulGradXPUKernel : public ElemwiseGradKernel<T> { class ElementwiseMulGradXPUKernel : public ElemwiseGradKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx); ElemwiseGradKernel<T>::Compute(ctx);
XPUElementwiseGrad<T>(ctx, xpu::mul_grad<T>, true); XPUElementwiseGrad<T, XPUType>(ctx, xpu::broadcast_mul_grad<XPUType>, true);
} }
}; };
...@@ -39,11 +44,10 @@ class ElementwiseMulGradXPUKernel : public ElemwiseGradKernel<T> { ...@@ -39,11 +44,10 @@ class ElementwiseMulGradXPUKernel : public ElemwiseGradKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(elementwise_mul, ops::ElementwiseMulXPUKernel<float>,
ops::ElementwiseMulXPUKernel<paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL( REGISTER_OP_XPU_KERNEL(
elementwise_mul, elementwise_mul_grad, ops::ElementwiseMulGradXPUKernel<float>,
ops::ElementwiseMulXPUKernel<paddle::platform::XPUDeviceContext, float>); ops::ElementwiseMulGradXPUKernel<paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(elementwise_mul_grad,
ops::ElementwiseMulGradXPUKernel<
paddle::platform::XPUDeviceContext, float>);
#endif #endif
...@@ -23,9 +23,11 @@ namespace operators { ...@@ -23,9 +23,11 @@ namespace operators {
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class ElementwisePowXPUKernel : public framework::OpKernel<T> { class ElementwisePowXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
XPUElementwise<T>(ctx, xpu::pow<float>); XPUElementwise<T, XPUType>(ctx, xpu::broadcast_pow<XPUType>);
} }
}; };
...@@ -35,6 +37,8 @@ class ElementwisePowXPUKernel : public framework::OpKernel<T> { ...@@ -35,6 +37,8 @@ class ElementwisePowXPUKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL( REGISTER_OP_XPU_KERNEL(
elementwise_pow, elementwise_pow,
ops::ElementwisePowXPUKernel<paddle::platform::XPUDeviceContext, float>); ops::ElementwisePowXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::ElementwisePowXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
#endif #endif
...@@ -21,20 +21,25 @@ limitations under the License. */ ...@@ -21,20 +21,25 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T>
class ElementwiseSubXPUKernel : public framework::OpKernel<T> { class ElementwiseSubXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
XPUElementwise<T>(ctx, xpu::sub<float>); XPUElementwise<T, XPUType>(ctx, xpu::broadcast_sub<XPUType>);
} }
}; };
template <typename DeviceContext, typename T> template <typename T>
class ElementwiseSubGradXPUKernel : public ElemwiseGradKernel<T> { class ElementwiseSubGradXPUKernel : public ElemwiseGradKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx); ElemwiseGradKernel<T>::Compute(ctx);
XPUElementwiseGrad<T>(ctx, xpu::sub_grad<float>, false); XPUElementwiseGrad<T, XPUType>(ctx, xpu::broadcast_sub_grad<XPUType>,
false);
} }
}; };
...@@ -42,11 +47,10 @@ class ElementwiseSubGradXPUKernel : public ElemwiseGradKernel<T> { ...@@ -42,11 +47,10 @@ class ElementwiseSubGradXPUKernel : public ElemwiseGradKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(elementwise_sub, ops::ElementwiseSubXPUKernel<float>,
ops::ElementwiseSubXPUKernel<paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL( REGISTER_OP_XPU_KERNEL(
elementwise_sub, elementwise_sub_grad, ops::ElementwiseSubGradXPUKernel<float>,
ops::ElementwiseSubXPUKernel<paddle::platform::XPUDeviceContext, float>); ops::ElementwiseSubGradXPUKernel<paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(elementwise_sub_grad,
ops::ElementwiseSubGradXPUKernel<
paddle::platform::XPUDeviceContext, float>);
#endif #endif
...@@ -25,64 +25,12 @@ limitations under the License. */ ...@@ -25,64 +25,12 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
static std::pair<std::vector<int>, std::vector<int>> XPUDimsToBroadcastVector( template <typename T, typename XPUType>
const framework::DDim& x, const framework::DDim& y) {
std::vector<int> x_v;
std::vector<int> y_v;
int y_size = y.size();
for (int i = 0; i < y_size; ++i) {
if (x[i] == y[i]) {
x_v.push_back(y[i]);
y_v.push_back(y[i]);
continue;
}
x_v.push_back(1);
x_v.push_back(x[i]);
y_v.push_back(y[i] / x[i]);
y_v.push_back(x[i]);
}
return std::make_pair(x_v, y_v);
}
static std::pair<std::vector<int>, std::vector<int>> XPUReducesAxisVector(
const framework::DDim& x, const framework::DDim& y) {
std::vector<int> x_vector;
std::vector<int> axis_v;
PADDLE_ENFORCE_GT(
x.size(), 0, platform::errors::OutOfRange("x size is less 1, x shape is ",
x.to_str()));
PADDLE_ENFORCE_GT(
y.size(), 0, platform::errors::OutOfRange("y size is less 1, y shape is ",
y.to_str()));
int y_nums = framework::product(y);
x_vector = framework::vectorize<int>(x);
if (y_nums == 1) {
for (int i = 0; i < x.size(); ++i) {
axis_v.push_back(i);
}
return std::make_pair(x_vector, axis_v);
}
int yidx = 0;
for (size_t i = 0; i < x_vector.size(); ++i) {
if (yidx >= y.size() || y[yidx] == 1) {
axis_v.push_back(i);
yidx++;
continue;
}
if (x_vector[i] != y[yidx]) {
axis_v.push_back(i);
continue;
}
yidx++;
}
return std::make_pair(x_vector, axis_v);
}
template <typename T>
void XPUElementwise( void XPUElementwise(
const framework::ExecutionContext& ctx, const framework::ExecutionContext& ctx,
std::function<int(xpu::Context*, const T*, const T*, T*, int)> func) { std::function<int(xpu::Context*, const XPUType*, const XPUType*, XPUType*,
const std::vector<int>&, const std::vector<int>&)>
func) {
auto x_var = ctx.InputVar("X"); auto x_var = ctx.InputVar("X");
PADDLE_ENFORCE_NE(x_var, nullptr, platform::errors::InvalidArgument( PADDLE_ENFORCE_NE(x_var, nullptr, platform::errors::InvalidArgument(
"Cannot get input Variable X")); "Cannot get input Variable X"));
...@@ -110,86 +58,59 @@ void XPUElementwise( ...@@ -110,86 +58,59 @@ void XPUElementwise(
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.", "Axis should be less than %d, but received axis is %d.",
max_dim, axis)); max_dim, axis));
std::vector<int> x_dims_vec(max_dim, 1);
std::vector<int> x_dims_array(max_dim); std::vector<int> y_dims_vec(max_dim, 1);
std::vector<int> y_dims_array(max_dim); if (x_dims.size() == max_dim) {
std::vector<int> out_dims_array(max_dim); for (int i = 0; i < max_dim; i++) {
GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(), x_dims_vec[i] = x_dims[i];
y_dims_array.data(), out_dims_array.data(), max_dim, }
axis); } else {
framework::DDim out_dim = framework::make_ddim(out_dims_array); for (int i = 0; i < x_dims.size(); i++) {
x_dims_vec[i + axis] = x_dims[i];
}
}
if (y_dims.size() == max_dim) {
for (int i = 0; i < max_dim; i++) {
y_dims_vec[i] = y_dims[i];
}
} else {
for (int i = 0; i < y_dims.size(); i++) {
y_dims_vec[i + axis] = y_dims[i];
}
}
const T* x_data = x.data<T>(); const T* x_data = x.data<T>();
const T* y_data = y->data<T>(); const T* y_data = y->data<T>();
T* z_data = z->data<T>(); T* z_data = z->data<T>();
bool need_wait = false;
framework::Tensor x_broadcast_tensor;
framework::Tensor y_broadcast_tensor;
auto& dev_ctx = auto& dev_ctx =
ctx.template device_context<paddle::platform::XPUDeviceContext>(); ctx.template device_context<paddle::platform::XPUDeviceContext>();
int ret = xpu::SUCCESS;
// begin broadcast now
if (x.numel() != z->numel()) {
// broadcast x
std::pair<std::vector<int>, std::vector<int>> bcast_v =
XPUDimsToBroadcastVector(framework::make_ddim(x_dims_array), out_dim);
ret = xpu::broadcast<T>(dev_ctx.x_context(), x_data, int ret = xpu::SUCCESS;
x_broadcast_tensor.mutable_data<T>(
ctx.GetPlace(), z->numel() * sizeof(T)),
bcast_v.first, bcast_v.second);
PADDLE_ENFORCE_EQ(
ret, xpu::SUCCESS,
platform::errors::External(
"XPU kernel broadcast occur error in XPUElementwise error code %d",
ret));
need_wait = true;
x_data = x_broadcast_tensor.data<T>();
}
if (y->numel() != z->numel()) { ret = func(dev_ctx.x_context(), reinterpret_cast<const XPUType*>(x_data),
// broadcast y reinterpret_cast<const XPUType*>(y_data),
std::vector<int> bcast_x_v; reinterpret_cast<XPUType*>(z_data), x_dims_vec, y_dims_vec);
std::vector<int> bcast_y_v;
std::pair<std::vector<int>, std::vector<int>> bcast_v =
XPUDimsToBroadcastVector(framework::make_ddim(y_dims_array), out_dim);
ret = xpu::broadcast<T>(dev_ctx.x_context(), y_data,
y_broadcast_tensor.mutable_data<T>(
ctx.GetPlace(), z->numel() * sizeof(T)),
bcast_v.first, bcast_v.second);
PADDLE_ENFORCE_EQ(
ret, xpu::SUCCESS,
platform::errors::External(
"XPU kernel broadcast occur error in XPUElementwise error code %d",
ret));
need_wait = true;
y_data = y_broadcast_tensor.data<T>();
}
int len = z->numel();
ret = func(dev_ctx.x_context(), x_data, y_data, z_data, len);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ret, xpu::SUCCESS, ret, xpu::SUCCESS,
platform::errors::External( platform::errors::External(
"XPU kernel Elementwise occur error in XPUElementwise error code ", "XPU kernel Elementwise occur error in XPUElementwise error code ",
ret)); ret, XPUAPIErrorMsg[ret]));
if (need_wait && dev_ctx.x_context()->xpu_stream) {
dev_ctx.Wait();
}
} }
template <typename T> template <typename T, typename XPUType>
void XPUElementwiseGrad(const framework::ExecutionContext& ctx, void XPUElementwiseGrad(
std::function<int(xpu::Context*, const T*, const T*, const framework::ExecutionContext& ctx,
const T*, const T*, T*, T*, int len)> std::function<int(xpu::Context*, const XPUType*, const XPUType*,
const XPUType*, const XPUType*, XPUType*, XPUType*,
const std::vector<int>&, const std::vector<int>&)>
func, func,
bool use_x_y_data) { bool use_x_y_data) {
auto* x = ctx.Input<framework::Tensor>("X"); auto* x = ctx.Input<framework::Tensor>("X");
auto* y = ctx.Input<framework::Tensor>("Y"); auto* y = ctx.Input<framework::Tensor>("Y");
auto* dz = ctx.Input<framework::Tensor>(framework::GradVarName("Out")); auto* dz = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* z = dz;
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X")); auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y")); auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
auto* z = dz;
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
const framework::DDim& x_dims = x->dims(); const framework::DDim& x_dims = x->dims();
const framework::DDim& y_dims = y->dims(); const framework::DDim& y_dims = y->dims();
...@@ -204,120 +125,55 @@ void XPUElementwiseGrad(const framework::ExecutionContext& ctx, ...@@ -204,120 +125,55 @@ void XPUElementwiseGrad(const framework::ExecutionContext& ctx,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.", "Axis should be less than %d, but received axis is %d.",
max_dim, axis)); max_dim, axis));
std::vector<int> x_dims_vec(max_dim, 1);
std::vector<int> y_dims_vec(max_dim, 1);
if (x_dims.size() == max_dim) {
for (int i = 0; i < max_dim; i++) {
x_dims_vec[i] = x_dims[i];
}
} else {
for (int i = 0; i < x_dims.size(); i++) {
x_dims_vec[i + axis] = x_dims[i];
}
}
if (y_dims.size() == max_dim) {
for (int i = 0; i < max_dim; i++) {
y_dims_vec[i] = y_dims[i];
}
} else {
for (int i = 0; i < y_dims.size(); i++) {
y_dims_vec[i + axis] = y_dims[i];
}
}
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
y_dims_array.data(), out_dims_array.data(), max_dim,
axis);
framework::DDim out_dim = framework::make_ddim(out_dims_array);
int len = framework::product(out_dim);
framework::Tensor x_broadcast_tensor;
framework::Tensor y_broadcast_tensor;
framework::Tensor dx_local_tensor;
framework::Tensor dy_local_tensor;
bool need_wait = false;
const T* x_data = use_x_y_data ? x->data<T>() : z->data<T>(); const T* x_data = use_x_y_data ? x->data<T>() : z->data<T>();
const T* y_data = use_x_y_data ? y->data<T>() : z->data<T>(); const T* y_data = use_x_y_data ? y->data<T>() : z->data<T>();
const T* z_data = z->data<T>(); const T* z_data = z->data<T>();
const T* dz_data = (const T*)dz->data<T>();
bool dx_need_reduce = (dx != nullptr) && (dx->numel() != len);
bool dy_need_reduce = (dy != nullptr) && (dy->numel() != len);
T* dx_data = const T* dz_data = dz->data<T>();
((dx == nullptr) || dx_need_reduce) T* dx_data = nullptr;
? (dx_local_tensor.mutable_data<T>(ctx.GetPlace(), len * sizeof(T))) T* dy_data = nullptr;
: (dx->mutable_data<T>(ctx.GetPlace()));
T* dy_data =
((dy == nullptr) || dy_need_reduce)
? (dy_local_tensor.mutable_data<T>(ctx.GetPlace(), len * sizeof(T)))
: (dy->mutable_data<T>(ctx.GetPlace()));
int ret = xpu::SUCCESS;
auto& dev_ctx = auto& dev_ctx =
ctx.template device_context<paddle::platform::XPUDeviceContext>(); ctx.template device_context<paddle::platform::XPUDeviceContext>();
if (use_x_y_data && x->numel() != len) { if (dx) {
std::vector<int> bcast_x_v; dx_data = dx->mutable_data<T>(ctx.GetPlace());
std::vector<int> bcast_y_v;
std::pair<std::vector<int>, std::vector<int>> bcast_v =
XPUDimsToBroadcastVector(framework::make_ddim(x_dims_array), out_dim);
ret = xpu::broadcast<T>(
dev_ctx.x_context(), x_data,
x_broadcast_tensor.mutable_data<T>(ctx.GetPlace(), len * sizeof(T)),
bcast_v.first, bcast_v.second);
PADDLE_ENFORCE_EQ(ret, xpu::SUCCESS,
platform::errors::External(
"XPU kernel broadcast error occur! %d", ret));
need_wait = true;
x_data = x_broadcast_tensor.data<T>();
}
if (use_x_y_data && y->numel() != len) {
// broadcast y
std::vector<int> bcast_x_v;
std::vector<int> bcast_y_v;
std::pair<std::vector<int>, std::vector<int>> bcast_v =
XPUDimsToBroadcastVector(framework::make_ddim(y_dims_array), out_dim);
ret = xpu::broadcast<T>(
dev_ctx.x_context(), y_data,
y_broadcast_tensor.mutable_data<T>(ctx.GetPlace(), len * sizeof(T)),
bcast_v.first, bcast_v.second);
PADDLE_ENFORCE_EQ(ret, xpu::SUCCESS,
platform::errors::External(
"XPU kernel broadcast error occur! %d", ret));
need_wait = true;
y_data = y_broadcast_tensor.data<T>();
} }
if (dy) {
ret = func(dev_ctx.x_context(), x_data, y_data, z_data, dz_data, dx_data, dy_data = dy->mutable_data<T>(ctx.GetPlace());
dy_data, len);
PADDLE_ENFORCE_EQ(ret, xpu::SUCCESS, platform::errors::External(
"XPU kernel binary occur error in "
"XPUElementwiseGrad, error code %d",
ret));
if (dx_need_reduce) {
const framework::DDim& dx_dims = dx->dims();
std::pair<std::vector<int>, std::vector<int>> reduce_v =
XPUReducesAxisVector(out_dim, dx_dims);
ret = xpu::reduce_sum<T>(dev_ctx.x_context(), dx_data,
dx->mutable_data<T>(ctx.GetPlace()),
reduce_v.first, reduce_v.second);
PADDLE_ENFORCE_EQ(
ret, xpu::SUCCESS,
platform::errors::External("XPU kernel reduce_sum occur error in "
"XPUElementwiseGrad, error code %d",
ret));
need_wait = true;
} }
if (dy_need_reduce) { int ret = func(dev_ctx.x_context(), reinterpret_cast<const XPUType*>(x_data),
const framework::DDim& dy_dims = dy->dims(); reinterpret_cast<const XPUType*>(y_data),
std::pair<std::vector<int>, std::vector<int>> reduce_v = reinterpret_cast<const XPUType*>(z_data),
XPUReducesAxisVector(out_dim, dy_dims); reinterpret_cast<const XPUType*>(dz_data),
ret = xpu::reduce_sum<T>(dev_ctx.x_context(), dy_data, reinterpret_cast<XPUType*>(dy_data),
dy->mutable_data<T>(ctx.GetPlace()), reinterpret_cast<XPUType*>(dx_data), x_dims_vec, y_dims_vec);
reduce_v.first, reduce_v.second);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ret, xpu::SUCCESS, ret, xpu::SUCCESS,
platform::errors::External("XPU kernel reduce_sum occur error in " platform::errors::External(
"XPUElementwiseGrad, error code %d", "XPU kernel Elementwise occur error in XPUElementwise error code ",
ret)); ret, XPUAPIErrorMsg[ret]));
need_wait = true;
}
if (need_wait && dev_ctx.x_context()->xpu_stream) {
dev_ctx.Wait();
}
} }
} // namespace operators } // namespace operators
......
...@@ -31,6 +31,48 @@ XPUOpMap& get_kl2_ops() { ...@@ -31,6 +31,48 @@ XPUOpMap& get_kl2_ops() {
static XPUOpMap s_xpu2_kernels{ static XPUOpMap s_xpu2_kernels{
{"mul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), {"mul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})}, pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_sub",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_sub_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_add",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_add_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_div",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_div_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_pow",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_floordiv",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_mul",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_mul_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_max",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_max_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_min",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_min_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
// AddMore // AddMore
}; };
......
...@@ -251,7 +251,10 @@ class TestRMSPropV2(XPUOpTest): ...@@ -251,7 +251,10 @@ class TestRMSPropV2(XPUOpTest):
cost = fluid.layers.square_error_cost(input=y_predict, label=y) cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost) avg_cost = fluid.layers.mean(cost)
rms_optimizer = paddle.optimizer.RMSProp(learning_rate=0.1) print(avg_cost.shape)
linear = paddle.nn.Linear(13, 5)
rms_optimizer = paddle.optimizer.RMSProp(
learning_rate=0.1, parameters=linear.parameters())
rms_optimizer.minimize(avg_cost) rms_optimizer.minimize(avg_cost)
fetch_list = [avg_cost] fetch_list = [avg_cost]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册