未验证 提交 ff96a7d5 编写于 作者: T taixiurong 提交者: GitHub

update elementwise api in kunlun (#35021)

上级 881e55e4
......@@ -35,7 +35,7 @@ ELSE ()
ENDIF()
SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210804")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210818")
SET(XPU_XRE_URL "${XPU_BASE_URL}/${XPU_XRE_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
SET(XPU_XDNN_URL "${XPU_BASE_URL}/${XPU_XDNN_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
SET(XPU_XCCL_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210623/${XPU_XCCL_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
......
......@@ -23,93 +23,45 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
template <typename T>
class ElementwiseAddXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// XPUElementwise<T>(ctx, xpu::add<T>);
// ToDo(QingshuChen): update this optimization to elementwise_xpu.h
auto x_var = ctx.InputVar("X");
PADDLE_ENFORCE_NE(x_var, nullptr, platform::errors::InvalidArgument(
"Cannot get input Variable X"));
PADDLE_ENFORCE_EQ(
x_var->IsType<framework::LoDTensor>(), true,
platform::errors::InvalidArgument(
"XPU only support LoDTensor, Input(X) is not LoDTensor"));
auto x = x_var->Get<framework::LoDTensor>();
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto* z = ctx.Output<framework::LoDTensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
auto x_dims = x.dims();
auto y_dims = y->dims();
int max_dim = std::max(x_dims.size(), y_dims.size());
int axis = ctx.Attr<int>("axis");
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
XPUElementwise<T, XPUType>(ctx, xpu::broadcast_add<XPUType>);
}
};
PADDLE_ENFORCE_GE(
axis, 0,
platform::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LT(
axis, max_dim,
platform::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.", max_dim,
axis));
std::vector<int> x_dims_vec(max_dim, 1);
std::vector<int> y_dims_vec(max_dim, 1);
if (x_dims.size() == max_dim) {
for (int i = 0; i < max_dim; i++) {
x_dims_vec[i] = x_dims[i];
}
} else {
for (int i = 0; i < x_dims.size(); i++) {
x_dims_vec[i + axis] = x_dims[i];
}
static std::vector<int> get_rdims(const std::vector<int>& xdims,
const std::vector<int>& ydims) {
std::vector<int> rdims;
for (size_t i = 0; i < xdims.size(); i++) {
if (xdims[i] != ydims[i]) {
rdims.push_back(i);
}
if (y_dims.size() == max_dim) {
for (int i = 0; i < max_dim; i++) {
y_dims_vec[i] = y_dims[i];
}
} else {
for (int i = 0; i < y_dims.size(); i++) {
y_dims_vec[i + axis] = y_dims[i];
}
}
const T* x_data = x.data<T>();
const T* y_data = y->data<T>();
T* z_data = z->data<T>();
auto& dev_ctx =
ctx.template device_context<paddle::platform::XPUDeviceContext>();
int ret = xpu::SUCCESS;
ret = xpu::broadcast_add<T>(dev_ctx.x_context(), x_data, y_data, z_data,
x_dims_vec, y_dims_vec);
PADDLE_ENFORCE_EQ(
ret, xpu::SUCCESS,
platform::errors::External(
"XPU kernel Elementwise occur error in XPUElementwise error code ",
ret, XPUAPIErrorMsg[ret]));
}
};
return rdims;
}
template <typename DeviceContext, typename T>
template <typename T>
class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
// XPUElementwiseGrad<T>(ctx, xpu::add_grad<T>, false);
auto* x = ctx.Input<framework::Tensor>("X");
auto* y = ctx.Input<framework::Tensor>("Y");
auto* dz = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
const framework::DDim& x_dims = x->dims();
const framework::DDim& y_dims = y->dims();
int max_dim = std::max(x_dims.size(), y_dims.size());
const framework::DDim& dz_dims = dz->dims();
int axis = ctx.Attr<int>("axis");
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
int max_dim = std::max(x_dims.size(), y_dims.size());
PADDLE_ENFORCE_GE(
axis, 0,
platform::errors::InvalidArgument(
......@@ -120,66 +72,74 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
platform::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.", max_dim,
axis));
std::vector<int> x_dims_vec(max_dim, 1);
std::vector<int> y_dims_vec(max_dim, 1);
int x_len = 1;
int y_len = 1;
std::vector<int> z_dims_vec(max_dim, 1);
if (x_dims.size() == max_dim) {
for (int i = 0; i < max_dim; i++) {
x_dims_vec[i] = x_dims[i];
x_len *= x_dims_vec[i];
}
} else {
for (int i = 0; i < x_dims.size(); i++) {
x_dims_vec[i + axis] = x_dims[i];
x_len *= x_dims_vec[i];
}
}
if (y_dims.size() == max_dim) {
for (int i = 0; i < max_dim; i++) {
y_dims_vec[i] = y_dims[i];
y_len *= y_dims_vec[i];
}
} else {
for (int i = 0; i < y_dims.size(); i++) {
y_dims_vec[i + axis] = y_dims[i];
y_len *= y_dims_vec[i];
}
}
const T* dz_data = dz->data<T>();
framework::Tensor dx_local_tensor;
framework::Tensor dy_local_tensor;
bool need_wait = false;
T* dx_data = nullptr;
T* dy_data = nullptr;
if (dx) {
dx_data = dx->mutable_data<T>(ctx.GetPlace());
} else {
dx_data =
dx_local_tensor.mutable_data<T>(ctx.GetPlace(), x_len * sizeof(T));
need_wait = true;
}
if (dy) {
dy_data = dy->mutable_data<T>(ctx.GetPlace());
} else {
dy_data =
dy_local_tensor.mutable_data<T>(ctx.GetPlace(), y_len * sizeof(T));
need_wait = true;
for (int i = 0; i < max_dim; i++) {
z_dims_vec[i] = dz_dims[i];
}
std::vector<int> rdims_for_x;
std::vector<int> rdims_for_y;
rdims_for_x = get_rdims(x_dims_vec, z_dims_vec);
rdims_for_y = get_rdims(y_dims_vec, z_dims_vec);
const T* dz_data = dz->data<T>();
auto& dev_ctx =
ctx.template device_context<paddle::platform::XPUDeviceContext>();
int ret = xpu::broadcast_add_grad<T>(dev_ctx.x_context(), dz_data, dz_data,
dz_data, dz_data, dy_data, dx_data,
x_dims_vec, y_dims_vec);
PADDLE_ENFORCE_EQ(
ret, xpu::SUCCESS,
platform::errors::External(
"XPU kernel Elementwise occur error in XPUElementwise error code ",
ret, XPUAPIErrorMsg[ret]));
if (need_wait && dev_ctx.x_context()->xpu_stream) {
dev_ctx.Wait();
if (dx != nullptr) {
if (rdims_for_x.size() == 0) {
framework::TensorCopy(
*dz, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), dx);
} else {
T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
int ret = xpu::reduce_sum<XPUType>(
dev_ctx.x_context(), reinterpret_cast<const XPUType*>(dz_data),
reinterpret_cast<XPUType*>(dx_data), z_dims_vec, rdims_for_x);
PADDLE_ENFORCE_EQ(
ret, xpu::SUCCESS,
platform::errors::External("XPU kernel reduce_sum occur error in "
"XPUElementwise error code ",
ret, XPUAPIErrorMsg[ret]));
}
}
if (dy != nullptr) {
if (rdims_for_y.size() == 0) {
framework::TensorCopy(
*dz, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), dy);
} else {
T* dy_data = dy->mutable_data<T>(ctx.GetPlace());
int ret = xpu::reduce_sum<XPUType>(
dev_ctx.x_context(), reinterpret_cast<const XPUType*>(dz_data),
reinterpret_cast<XPUType*>(dy_data), z_dims_vec, rdims_for_y);
PADDLE_ENFORCE_EQ(
ret, xpu::SUCCESS,
platform::errors::External("XPU kernel reduce_sum occur error in "
"XPUElementwise error code ",
ret, XPUAPIErrorMsg[ret]));
}
}
}
};
......@@ -189,10 +149,9 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(elementwise_add, ops::ElementwiseAddXPUKernel<float>,
ops::ElementwiseAddXPUKernel<paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(
elementwise_add,
ops::ElementwiseAddXPUKernel<paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(elementwise_add_grad,
ops::ElementwiseAddGradXPUKernel<
paddle::platform::XPUDeviceContext, float>);
elementwise_add_grad, ops::ElementwiseAddGradXPUKernel<float>,
ops::ElementwiseAddGradXPUKernel<paddle::platform::float16>);
#endif
......@@ -19,30 +19,33 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
template <typename T>
class ElementwiseDivXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
XPUElementwise<T>(ctx, xpu::div<T>);
XPUElementwise<T, XPUType>(ctx, xpu::broadcast_div<XPUType>);
}
};
template <typename DeviceContext, typename T>
template <typename T>
class ElementwiseDivGradXPUKernel : public ElemwiseGradKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
XPUElementwiseGrad<T>(ctx, xpu::div_grad<T>, true);
XPUElementwiseGrad<T, XPUType>(ctx, xpu::broadcast_div_grad<XPUType>, true);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(elementwise_div, ops::ElementwiseDivXPUKernel<float>,
ops::ElementwiseDivXPUKernel<paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(
elementwise_div,
ops::ElementwiseDivXPUKernel<paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(elementwise_div_grad,
ops::ElementwiseDivGradXPUKernel<
paddle::platform::XPUDeviceContext, float>);
elementwise_div_grad, ops::ElementwiseDivGradXPUKernel<float>,
ops::ElementwiseDivGradXPUKernel<paddle::platform::float16>);
#endif
......@@ -21,17 +21,22 @@ namespace operators {
template <typename DeviceContext, typename T>
class ElementwiseFloordivXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
XPUElementwise<T>(ctx, xpu::floordiv<T>);
XPUElementwise<T, XPUType>(ctx, xpu::broadcast_floordiv<XPUType>);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(elementwise_floordiv,
ops::ElementwiseFloordivXPUKernel<
paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(
elementwise_floordiv,
ops::ElementwiseFloordivXPUKernel<paddle::platform::XPUDeviceContext,
float>,
ops::ElementwiseFloordivXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
#endif
......@@ -20,20 +20,24 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
template <typename T>
class ElementwiseMaxXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
XPUElementwise<T>(ctx, xpu::max<T>);
XPUElementwise<T, XPUType>(ctx, xpu::broadcast_max<XPUType>);
}
};
template <typename DeviceContext, typename T>
template <typename T>
class ElementwiseMaxGradXPUKernel : public ElemwiseGradKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
XPUElementwiseGrad<T>(ctx, xpu::max_grad<T>, true);
XPUElementwiseGrad<T, XPUType>(ctx, xpu::broadcast_max_grad<XPUType>, true);
}
};
......@@ -41,10 +45,9 @@ class ElementwiseMaxGradXPUKernel : public ElemwiseGradKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(elementwise_max, ops::ElementwiseMaxXPUKernel<float>,
ops::ElementwiseMaxXPUKernel<paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(
elementwise_max,
ops::ElementwiseMaxXPUKernel<paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(elementwise_max_grad,
ops::ElementwiseMaxGradXPUKernel<
paddle::platform::XPUDeviceContext, float>);
elementwise_max_grad, ops::ElementwiseMaxGradXPUKernel<float>,
ops::ElementwiseMaxGradXPUKernel<paddle::platform::float16>);
#endif
......@@ -20,20 +20,24 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
template <typename T>
class ElementwiseMinXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
XPUElementwise<T>(ctx, xpu::min<T>);
XPUElementwise<T, XPUType>(ctx, xpu::broadcast_min<XPUType>);
}
};
template <typename DeviceContext, typename T>
template <typename T>
class ElementwiseMinGradXPUKernel : public ElemwiseGradKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
XPUElementwiseGrad<T>(ctx, xpu::min_grad<T>, true);
XPUElementwiseGrad<T, XPUType>(ctx, xpu::broadcast_min_grad<XPUType>, true);
}
};
......@@ -41,10 +45,9 @@ class ElementwiseMinGradXPUKernel : public ElemwiseGradKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(elementwise_min, ops::ElementwiseMinXPUKernel<float>,
ops::ElementwiseMinXPUKernel<paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(
elementwise_min,
ops::ElementwiseMinXPUKernel<paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(elementwise_min_grad,
ops::ElementwiseMinGradXPUKernel<
paddle::platform::XPUDeviceContext, float>);
elementwise_min_grad, ops::ElementwiseMinGradXPUKernel<float>,
ops::ElementwiseMinGradXPUKernel<paddle::platform::float16>);
#endif
......@@ -18,20 +18,25 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_xpu.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
template <typename T>
class ElementwiseMulXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
XPUElementwise<T>(ctx, xpu::mul<T>);
XPUElementwise<T, XPUType>(ctx, xpu::broadcast_mul<XPUType>);
}
};
// DEFINE_XPU_GRAD_KERNEL(Mul, mul, true);
template <typename DeviceContext, typename T>
template <typename T>
class ElementwiseMulGradXPUKernel : public ElemwiseGradKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
XPUElementwiseGrad<T>(ctx, xpu::mul_grad<T>, true);
XPUElementwiseGrad<T, XPUType>(ctx, xpu::broadcast_mul_grad<XPUType>, true);
}
};
......@@ -39,11 +44,10 @@ class ElementwiseMulGradXPUKernel : public ElemwiseGradKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(elementwise_mul, ops::ElementwiseMulXPUKernel<float>,
ops::ElementwiseMulXPUKernel<paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(
elementwise_mul,
ops::ElementwiseMulXPUKernel<paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(elementwise_mul_grad,
ops::ElementwiseMulGradXPUKernel<
paddle::platform::XPUDeviceContext, float>);
elementwise_mul_grad, ops::ElementwiseMulGradXPUKernel<float>,
ops::ElementwiseMulGradXPUKernel<paddle::platform::float16>);
#endif
......@@ -23,9 +23,11 @@ namespace operators {
template <typename DeviceContext, typename T>
class ElementwisePowXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
XPUElementwise<T>(ctx, xpu::pow<float>);
XPUElementwise<T, XPUType>(ctx, xpu::broadcast_pow<XPUType>);
}
};
......@@ -35,6 +37,8 @@ class ElementwisePowXPUKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
elementwise_pow,
ops::ElementwisePowXPUKernel<paddle::platform::XPUDeviceContext, float>);
ops::ElementwisePowXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::ElementwisePowXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
#endif
......@@ -21,20 +21,25 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
template <typename T>
class ElementwiseSubXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
XPUElementwise<T>(ctx, xpu::sub<float>);
XPUElementwise<T, XPUType>(ctx, xpu::broadcast_sub<XPUType>);
}
};
template <typename DeviceContext, typename T>
template <typename T>
class ElementwiseSubGradXPUKernel : public ElemwiseGradKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
XPUElementwiseGrad<T>(ctx, xpu::sub_grad<float>, false);
XPUElementwiseGrad<T, XPUType>(ctx, xpu::broadcast_sub_grad<XPUType>,
false);
}
};
......@@ -42,11 +47,10 @@ class ElementwiseSubGradXPUKernel : public ElemwiseGradKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(elementwise_sub, ops::ElementwiseSubXPUKernel<float>,
ops::ElementwiseSubXPUKernel<paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(
elementwise_sub,
ops::ElementwiseSubXPUKernel<paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(elementwise_sub_grad,
ops::ElementwiseSubGradXPUKernel<
paddle::platform::XPUDeviceContext, float>);
elementwise_sub_grad, ops::ElementwiseSubGradXPUKernel<float>,
ops::ElementwiseSubGradXPUKernel<paddle::platform::float16>);
#endif
......@@ -25,64 +25,12 @@ limitations under the License. */
namespace paddle {
namespace operators {
static std::pair<std::vector<int>, std::vector<int>> XPUDimsToBroadcastVector(
const framework::DDim& x, const framework::DDim& y) {
std::vector<int> x_v;
std::vector<int> y_v;
int y_size = y.size();
for (int i = 0; i < y_size; ++i) {
if (x[i] == y[i]) {
x_v.push_back(y[i]);
y_v.push_back(y[i]);
continue;
}
x_v.push_back(1);
x_v.push_back(x[i]);
y_v.push_back(y[i] / x[i]);
y_v.push_back(x[i]);
}
return std::make_pair(x_v, y_v);
}
static std::pair<std::vector<int>, std::vector<int>> XPUReducesAxisVector(
const framework::DDim& x, const framework::DDim& y) {
std::vector<int> x_vector;
std::vector<int> axis_v;
PADDLE_ENFORCE_GT(
x.size(), 0, platform::errors::OutOfRange("x size is less 1, x shape is ",
x.to_str()));
PADDLE_ENFORCE_GT(
y.size(), 0, platform::errors::OutOfRange("y size is less 1, y shape is ",
y.to_str()));
int y_nums = framework::product(y);
x_vector = framework::vectorize<int>(x);
if (y_nums == 1) {
for (int i = 0; i < x.size(); ++i) {
axis_v.push_back(i);
}
return std::make_pair(x_vector, axis_v);
}
int yidx = 0;
for (size_t i = 0; i < x_vector.size(); ++i) {
if (yidx >= y.size() || y[yidx] == 1) {
axis_v.push_back(i);
yidx++;
continue;
}
if (x_vector[i] != y[yidx]) {
axis_v.push_back(i);
continue;
}
yidx++;
}
return std::make_pair(x_vector, axis_v);
}
template <typename T>
template <typename T, typename XPUType>
void XPUElementwise(
const framework::ExecutionContext& ctx,
std::function<int(xpu::Context*, const T*, const T*, T*, int)> func) {
std::function<int(xpu::Context*, const XPUType*, const XPUType*, XPUType*,
const std::vector<int>&, const std::vector<int>&)>
func) {
auto x_var = ctx.InputVar("X");
PADDLE_ENFORCE_NE(x_var, nullptr, platform::errors::InvalidArgument(
"Cannot get input Variable X"));
......@@ -110,86 +58,59 @@ void XPUElementwise(
platform::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
max_dim, axis));
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
y_dims_array.data(), out_dims_array.data(), max_dim,
axis);
framework::DDim out_dim = framework::make_ddim(out_dims_array);
std::vector<int> x_dims_vec(max_dim, 1);
std::vector<int> y_dims_vec(max_dim, 1);
if (x_dims.size() == max_dim) {
for (int i = 0; i < max_dim; i++) {
x_dims_vec[i] = x_dims[i];
}
} else {
for (int i = 0; i < x_dims.size(); i++) {
x_dims_vec[i + axis] = x_dims[i];
}
}
if (y_dims.size() == max_dim) {
for (int i = 0; i < max_dim; i++) {
y_dims_vec[i] = y_dims[i];
}
} else {
for (int i = 0; i < y_dims.size(); i++) {
y_dims_vec[i + axis] = y_dims[i];
}
}
const T* x_data = x.data<T>();
const T* y_data = y->data<T>();
T* z_data = z->data<T>();
bool need_wait = false;
framework::Tensor x_broadcast_tensor;
framework::Tensor y_broadcast_tensor;
auto& dev_ctx =
ctx.template device_context<paddle::platform::XPUDeviceContext>();
int ret = xpu::SUCCESS;
// begin broadcast now
if (x.numel() != z->numel()) {
// broadcast x
std::pair<std::vector<int>, std::vector<int>> bcast_v =
XPUDimsToBroadcastVector(framework::make_ddim(x_dims_array), out_dim);
ret = xpu::broadcast<T>(dev_ctx.x_context(), x_data,
x_broadcast_tensor.mutable_data<T>(
ctx.GetPlace(), z->numel() * sizeof(T)),
bcast_v.first, bcast_v.second);
PADDLE_ENFORCE_EQ(
ret, xpu::SUCCESS,
platform::errors::External(
"XPU kernel broadcast occur error in XPUElementwise error code %d",
ret));
need_wait = true;
x_data = x_broadcast_tensor.data<T>();
}
int ret = xpu::SUCCESS;
if (y->numel() != z->numel()) {
// broadcast y
std::vector<int> bcast_x_v;
std::vector<int> bcast_y_v;
std::pair<std::vector<int>, std::vector<int>> bcast_v =
XPUDimsToBroadcastVector(framework::make_ddim(y_dims_array), out_dim);
ret = xpu::broadcast<T>(dev_ctx.x_context(), y_data,
y_broadcast_tensor.mutable_data<T>(
ctx.GetPlace(), z->numel() * sizeof(T)),
bcast_v.first, bcast_v.second);
PADDLE_ENFORCE_EQ(
ret, xpu::SUCCESS,
platform::errors::External(
"XPU kernel broadcast occur error in XPUElementwise error code %d",
ret));
need_wait = true;
y_data = y_broadcast_tensor.data<T>();
}
int len = z->numel();
ret = func(dev_ctx.x_context(), x_data, y_data, z_data, len);
ret = func(dev_ctx.x_context(), reinterpret_cast<const XPUType*>(x_data),
reinterpret_cast<const XPUType*>(y_data),
reinterpret_cast<XPUType*>(z_data), x_dims_vec, y_dims_vec);
PADDLE_ENFORCE_EQ(
ret, xpu::SUCCESS,
platform::errors::External(
"XPU kernel Elementwise occur error in XPUElementwise error code ",
ret));
if (need_wait && dev_ctx.x_context()->xpu_stream) {
dev_ctx.Wait();
}
ret, XPUAPIErrorMsg[ret]));
}
template <typename T>
void XPUElementwiseGrad(const framework::ExecutionContext& ctx,
std::function<int(xpu::Context*, const T*, const T*,
const T*, const T*, T*, T*, int len)>
func,
bool use_x_y_data) {
template <typename T, typename XPUType>
void XPUElementwiseGrad(
const framework::ExecutionContext& ctx,
std::function<int(xpu::Context*, const XPUType*, const XPUType*,
const XPUType*, const XPUType*, XPUType*, XPUType*,
const std::vector<int>&, const std::vector<int>&)>
func,
bool use_x_y_data) {
auto* x = ctx.Input<framework::Tensor>("X");
auto* y = ctx.Input<framework::Tensor>("Y");
auto* dz = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* z = dz;
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
auto* z = dz;
int axis = ctx.Attr<int>("axis");
const framework::DDim& x_dims = x->dims();
const framework::DDim& y_dims = y->dims();
......@@ -204,120 +125,55 @@ void XPUElementwiseGrad(const framework::ExecutionContext& ctx,
platform::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
max_dim, axis));
std::vector<int> x_dims_vec(max_dim, 1);
std::vector<int> y_dims_vec(max_dim, 1);
if (x_dims.size() == max_dim) {
for (int i = 0; i < max_dim; i++) {
x_dims_vec[i] = x_dims[i];
}
} else {
for (int i = 0; i < x_dims.size(); i++) {
x_dims_vec[i + axis] = x_dims[i];
}
}
if (y_dims.size() == max_dim) {
for (int i = 0; i < max_dim; i++) {
y_dims_vec[i] = y_dims[i];
}
} else {
for (int i = 0; i < y_dims.size(); i++) {
y_dims_vec[i + axis] = y_dims[i];
}
}
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
y_dims_array.data(), out_dims_array.data(), max_dim,
axis);
framework::DDim out_dim = framework::make_ddim(out_dims_array);
int len = framework::product(out_dim);
framework::Tensor x_broadcast_tensor;
framework::Tensor y_broadcast_tensor;
framework::Tensor dx_local_tensor;
framework::Tensor dy_local_tensor;
bool need_wait = false;
const T* x_data = use_x_y_data ? x->data<T>() : z->data<T>();
const T* y_data = use_x_y_data ? y->data<T>() : z->data<T>();
const T* z_data = z->data<T>();
const T* dz_data = (const T*)dz->data<T>();
bool dx_need_reduce = (dx != nullptr) && (dx->numel() != len);
bool dy_need_reduce = (dy != nullptr) && (dy->numel() != len);
T* dx_data =
((dx == nullptr) || dx_need_reduce)
? (dx_local_tensor.mutable_data<T>(ctx.GetPlace(), len * sizeof(T)))
: (dx->mutable_data<T>(ctx.GetPlace()));
T* dy_data =
((dy == nullptr) || dy_need_reduce)
? (dy_local_tensor.mutable_data<T>(ctx.GetPlace(), len * sizeof(T)))
: (dy->mutable_data<T>(ctx.GetPlace()));
int ret = xpu::SUCCESS;
const T* dz_data = dz->data<T>();
T* dx_data = nullptr;
T* dy_data = nullptr;
auto& dev_ctx =
ctx.template device_context<paddle::platform::XPUDeviceContext>();
if (use_x_y_data && x->numel() != len) {
std::vector<int> bcast_x_v;
std::vector<int> bcast_y_v;
std::pair<std::vector<int>, std::vector<int>> bcast_v =
XPUDimsToBroadcastVector(framework::make_ddim(x_dims_array), out_dim);
ret = xpu::broadcast<T>(
dev_ctx.x_context(), x_data,
x_broadcast_tensor.mutable_data<T>(ctx.GetPlace(), len * sizeof(T)),
bcast_v.first, bcast_v.second);
PADDLE_ENFORCE_EQ(ret, xpu::SUCCESS,
platform::errors::External(
"XPU kernel broadcast error occur! %d", ret));
need_wait = true;
x_data = x_broadcast_tensor.data<T>();
}
if (use_x_y_data && y->numel() != len) {
// broadcast y
std::vector<int> bcast_x_v;
std::vector<int> bcast_y_v;
std::pair<std::vector<int>, std::vector<int>> bcast_v =
XPUDimsToBroadcastVector(framework::make_ddim(y_dims_array), out_dim);
ret = xpu::broadcast<T>(
dev_ctx.x_context(), y_data,
y_broadcast_tensor.mutable_data<T>(ctx.GetPlace(), len * sizeof(T)),
bcast_v.first, bcast_v.second);
PADDLE_ENFORCE_EQ(ret, xpu::SUCCESS,
platform::errors::External(
"XPU kernel broadcast error occur! %d", ret));
need_wait = true;
y_data = y_broadcast_tensor.data<T>();
}
ret = func(dev_ctx.x_context(), x_data, y_data, z_data, dz_data, dx_data,
dy_data, len);
PADDLE_ENFORCE_EQ(ret, xpu::SUCCESS, platform::errors::External(
"XPU kernel binary occur error in "
"XPUElementwiseGrad, error code %d",
ret));
if (dx_need_reduce) {
const framework::DDim& dx_dims = dx->dims();
std::pair<std::vector<int>, std::vector<int>> reduce_v =
XPUReducesAxisVector(out_dim, dx_dims);
ret = xpu::reduce_sum<T>(dev_ctx.x_context(), dx_data,
dx->mutable_data<T>(ctx.GetPlace()),
reduce_v.first, reduce_v.second);
PADDLE_ENFORCE_EQ(
ret, xpu::SUCCESS,
platform::errors::External("XPU kernel reduce_sum occur error in "
"XPUElementwiseGrad, error code %d",
ret));
need_wait = true;
if (dx) {
dx_data = dx->mutable_data<T>(ctx.GetPlace());
}
if (dy_need_reduce) {
const framework::DDim& dy_dims = dy->dims();
std::pair<std::vector<int>, std::vector<int>> reduce_v =
XPUReducesAxisVector(out_dim, dy_dims);
ret = xpu::reduce_sum<T>(dev_ctx.x_context(), dy_data,
dy->mutable_data<T>(ctx.GetPlace()),
reduce_v.first, reduce_v.second);
PADDLE_ENFORCE_EQ(
ret, xpu::SUCCESS,
platform::errors::External("XPU kernel reduce_sum occur error in "
"XPUElementwiseGrad, error code %d",
ret));
need_wait = true;
if (dy) {
dy_data = dy->mutable_data<T>(ctx.GetPlace());
}
if (need_wait && dev_ctx.x_context()->xpu_stream) {
dev_ctx.Wait();
}
int ret = func(dev_ctx.x_context(), reinterpret_cast<const XPUType*>(x_data),
reinterpret_cast<const XPUType*>(y_data),
reinterpret_cast<const XPUType*>(z_data),
reinterpret_cast<const XPUType*>(dz_data),
reinterpret_cast<XPUType*>(dy_data),
reinterpret_cast<XPUType*>(dx_data), x_dims_vec, y_dims_vec);
PADDLE_ENFORCE_EQ(
ret, xpu::SUCCESS,
platform::errors::External(
"XPU kernel Elementwise occur error in XPUElementwise error code ",
ret, XPUAPIErrorMsg[ret]));
}
} // namespace operators
......
......@@ -31,6 +31,48 @@ XPUOpMap& get_kl2_ops() {
static XPUOpMap s_xpu2_kernels{
{"mul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_sub",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_sub_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_add",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_add_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_div",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_div_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_pow",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_floordiv",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_mul",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_mul_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_max",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_max_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_min",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_min_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
// AddMore
};
......
......@@ -251,7 +251,10 @@ class TestRMSPropV2(XPUOpTest):
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)
rms_optimizer = paddle.optimizer.RMSProp(learning_rate=0.1)
print(avg_cost.shape)
linear = paddle.nn.Linear(13, 5)
rms_optimizer = paddle.optimizer.RMSProp(
learning_rate=0.1, parameters=linear.parameters())
rms_optimizer.minimize(avg_cost)
fetch_list = [avg_cost]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册