Unverified commit 70b9f2ac, authored by taixiurong, committed by GitHub

dropout support Seed, fix elementwise_add_grad bug, test=kunlun (#39656)

Parent 8363406a
@@ -32,20 +32,18 @@ class DropoutXPUKernel : public framework::OpKernel<T> {
         context.Attr<std::string>("dropout_implementation");
     auto& dev_ctx = context.template device_context<DeviceContext>();
-    PADDLE_ENFORCE_EQ(!context.HasInput("Seed"), true,
-                      platform::errors::InvalidArgument(
-                          ("Input(Seed) not supported on XPU")));
+    auto* seed =
+        context.HasInput("Seed") ? context.Input<Tensor>("Seed") : nullptr;
     int is_upscale = (dropout_implementation == "upscale_in_train");
     if (!context.Attr<bool>("is_test")) {
-      std::random_device rnd;
-      // int seed = (context.Attr<bool>("fix_seed")) ?
-      // int(context.Attr<int>("seed")) : (rnd());
-      int seed = 0;
-      if (context.Attr<bool>("fix_seed") == true) {
-        seed = static_cast<int>(context.Attr<int>("seed"));
+      int seed_data = 0;
+      if (seed) {
+        seed_data = *(seed->data<int>());
       } else {
-        seed = rnd();
+        seed_data =
+            context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : 0;
       }
       auto* mask = context.Output<Tensor>("Mask");
@@ -55,26 +53,26 @@ class DropoutXPUKernel : public framework::OpKernel<T> {
         int r = xpu::constant(dev_ctx.x_context(),
                               reinterpret_cast<XPUTyp*>(y_data), y->numel(),
                               XPUTyp(0));
-        PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant ");
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
         r = xpu::constant(dev_ctx.x_context(),
                           reinterpret_cast<XPUTyp*>(mask_data), mask->numel(),
                           XPUTyp(0));
-        PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant ");
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
         return;
       }
       int r = xpu::dropout(dev_ctx.x_context(),
                            reinterpret_cast<const XPUTyp*>(x->data<T>()),
                            reinterpret_cast<XPUTyp*>(y->data<T>()),
-                           reinterpret_cast<XPUTyp*>(mask_data), seed,
+                           reinterpret_cast<XPUTyp*>(mask_data), seed_data,
                            mask->numel(), is_upscale, dropout_prob);
-      PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout ");
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout");
     } else {
       float scale =
           (is_upscale) ? (1.0) : (static_cast<float>(1.0f - dropout_prob));
       int r = xpu::scale(
           dev_ctx.x_context(), reinterpret_cast<const XPUTyp*>(x_data),
           reinterpret_cast<XPUTyp*>(y_data), x->numel(), false, scale, 0.0f);
-      PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale ");
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
     }
   }
 };
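
The hunks above change where the dropout seed comes from: a bound "Seed" input tensor now takes priority, then the fix_seed/seed attributes, and otherwise 0 is handed to xpu::dropout (presumably leaving seeding to the XPU runtime; the old host-side std::random_device draw is gone). A minimal standalone sketch of that selection order, using hypothetical names — the real kernel reads these values from the ExecutionContext:

// Hypothetical stand-in for the inputs/attributes the kernel reads.
struct DropoutSeedSource {
  bool has_seed_tensor;   // context.HasInput("Seed")
  int seed_tensor_value;  // *(seed->data<int>())
  bool fix_seed;          // context.Attr<bool>("fix_seed")
  int seed_attr;          // context.Attr<int>("seed")
};

int SelectDropoutSeed(const DropoutSeedSource& s) {
  if (s.has_seed_tensor) {
    return s.seed_tensor_value;  // a runtime Seed tensor wins
  }
  // With fix_seed, the attribute gives a deterministic seed;
  // otherwise 0 is passed through to xpu::dropout.
  return s.fix_seed ? s.seed_attr : 0;
}
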
@@ -103,7 +101,7 @@ class DropoutGradXPUKernel : public framework::OpKernel<T> {
                        reinterpret_cast<const XPUType*>(mask_data),
                        reinterpret_cast<XPUType*>(grad_x->data<T>()),
                        grad_y->numel());
-      PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul ");
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul");
       return;
     }
@@ -117,13 +115,13 @@ class DropoutGradXPUKernel : public framework::OpKernel<T> {
                          reinterpret_cast<const XPUType*>(mask->data<T>()),
                          reinterpret_cast<XPUType*>(mask_new), mask->numel(),
                          false, scale, 0.0f);
-      PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale ");
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
       r = xpu::mul(dev_ctx.x_context(),
                    reinterpret_cast<const XPUType*>(grad_y->data<T>()),
                    reinterpret_cast<const XPUType*>(mask_new),
                    reinterpret_cast<XPUType*>(grad_x->data<T>()),
                    grad_y->numel());
-      PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul ");
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul");
     } else {
       int r =
           xpu::dropout_grad(dev_ctx.x_context(),
@@ -131,7 +129,7 @@ class DropoutGradXPUKernel : public framework::OpKernel<T> {
                             reinterpret_cast<const XPUType*>(grad_y->data<T>()),
                             reinterpret_cast<XPUType*>(grad_x->data<T>()),
                             dropout_prob, grad_y->numel());
-      PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout_grad ");
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout_grad");
     }
   }
 };
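For reference, the gradient this file computes reduces to dx = dy * mask, rescaled by 1/(1-p) only under "upscale_in_train" — which the kernel does either by scaling the mask first (xpu::scale + xpu::mul) or fused in xpu::dropout_grad. A plain C++ sketch, assuming the mask holds 0/1 keep flags:

#include <vector>

std::vector<float> DropoutBackwardRef(const std::vector<float>& dy,
                                      const std::vector<float>& mask,
                                      float p, bool upscale) {
  // Under upscale_in_train the forward pass divided kept activations by
  // (1 - p), so the backward pass must apply the same factor.
  const float scale = (upscale && p < 1.0f) ? 1.0f / (1.0f - p) : 1.0f;
  std::vector<float> dx(dy.size());
  for (std::size_t i = 0; i < dy.size(); ++i) {
    dx[i] = dy[i] * mask[i] * scale;  // masked (and possibly rescaled) grad
  }
  return dx;
}
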
@@ -34,17 +34,6 @@ class ElementwiseAddXPUKernel : public framework::OpKernel<T> {
   }
 };
-static std::vector<int> get_rdims(const std::vector<int>& xdims,
-                                  const std::vector<int>& ydims) {
-  std::vector<int> rdims;
-  for (size_t i = 0; i < xdims.size(); i++) {
-    if (xdims[i] != ydims[i]) {
-      rdims.push_back(i);
-    }
-  }
-  return rdims;
-}
 template <typename T>
 class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
   using XPUType = typename XPUTypeTrait<T>::Type;
@@ -53,64 +42,19 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     ElemwiseGradKernel<T>::Compute(ctx);
     auto* x = ctx.Input<framework::Tensor>("X");
-    auto* y = ctx.Input<framework::Tensor>("Y");
     auto* dz = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
     auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
-    const framework::DDim& x_dims = x->dims();
-    const framework::DDim& y_dims = y->dims();
     const framework::DDim& dz_dims = dz->dims();
     int axis = ctx.Attr<int>("axis");
-    axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
-    int max_dim = std::max(x_dims.size(), y_dims.size());
-    PADDLE_ENFORCE_GE(
-        axis, 0,
-        platform::errors::InvalidArgument(
-            "Axis should be great than or equal to 0, but received axis is %d.",
-            axis));
-    PADDLE_ENFORCE_LT(
-        axis, max_dim,
-        platform::errors::InvalidArgument(
-            "Axis should be less than %d, but received axis is %d.", max_dim,
-            axis));
-    std::vector<int> x_dims_vec(max_dim, 1);
-    std::vector<int> y_dims_vec(max_dim, 1);
-    std::vector<int> z_dims_vec(max_dim, 1);
-    if (x_dims.size() == max_dim) {
-      for (int i = 0; i < max_dim; i++) {
-        x_dims_vec[i] = x_dims[i];
-      }
-    } else {
-      for (int i = 0; i < x_dims.size(); i++) {
-        x_dims_vec[i + axis] = x_dims[i];
-      }
-    }
-    if (y_dims.size() == max_dim) {
-      for (int i = 0; i < max_dim; i++) {
-        y_dims_vec[i] = y_dims[i];
-      }
-    } else {
-      for (int i = 0; i < y_dims.size(); i++) {
-        y_dims_vec[i + axis] = y_dims[i];
-      }
-    }
-    for (int i = 0; i < max_dim; i++) {
-      z_dims_vec[i] = dz_dims[i];
-    }
-    std::vector<int> rdims_for_x;
-    std::vector<int> rdims_for_y;
-    rdims_for_x = get_rdims(x_dims_vec, z_dims_vec);
-    rdims_for_y = get_rdims(y_dims_vec, z_dims_vec);
     const T* dz_data = dz->data<T>();
     auto& dev_ctx =
         ctx.template device_context<paddle::platform::XPUDeviceContext>();
     if (dx != nullptr) {
       T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
-      if (rdims_for_x.size() == 0) {
+      if (dx->dims() == dz_dims) {
         if (dx_data != dz_data) {
           framework::TensorCopy(
               *dz, ctx.GetPlace(),
@@ -123,27 +67,31 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
           dx->clear();
           dx->mutable_data<T>(x->dims(), ctx.GetPlace());
         }
+        std::vector<int> reduce_dims = GetReduceDim(dx->dims(), dz_dims, axis);
+        std::vector<int> dz_vector = framework::vectorize<int>(dz_dims);
         int ret = xpu::reduce_sum<XPUType>(
             dev_ctx.x_context(), reinterpret_cast<const XPUType*>(dz_data),
-            reinterpret_cast<XPUType*>(dx_data), z_dims_vec, rdims_for_x);
-        PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_sum ");
+            reinterpret_cast<XPUType*>(dx_data), dz_vector, reduce_dims);
+        PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_sum");
       }
     }
     if (dy != nullptr) {
       T* dy_data = dy->mutable_data<T>(ctx.GetPlace());
-      if (rdims_for_y.size() == 0) {
+      if (dy->dims() == dz_dims) {
         if (dy_data != dz_data) {
           framework::TensorCopy(
               *dz, ctx.GetPlace(),
               ctx.template device_context<platform::DeviceContext>(), dy);
         }
       } else {
+        std::vector<int> reduce_dims = GetReduceDim(dy->dims(), dz_dims, axis);
+        std::vector<int> dz_vector = framework::vectorize<int>(dz_dims);
         int ret = xpu::reduce_sum<XPUType>(
             dev_ctx.x_context(), reinterpret_cast<const XPUType*>(dz_data),
-            reinterpret_cast<XPUType*>(dy_data), z_dims_vec, rdims_for_y);
-        PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_sum ");
+            reinterpret_cast<XPUType*>(dy_data), dz_vector, reduce_dims);
+            PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_sum");
       }
     }
   }
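
The gradient fix above replaces the hand-rolled get_rdims — which assumed both shape vectors had already been padded to the same rank and merely collected mismatching axes — with a dims()-equality fast path plus the shared GetReduceDim helper. A sketch of the reduce-dim computation this relies on, under assumed broadcasting rules (the input shape aligns into the output shape at offset axis; every output axis the input does not cover, or covers with extent 1, must be summed over); ReduceDimsForBroadcast is a hypothetical stand-in, not Paddle's GetReduceDim itself:

#include <cassert>
#include <vector>

std::vector<int> ReduceDimsForBroadcast(const std::vector<int>& in_dims,
                                        const std::vector<int>& out_dims,
                                        int axis) {
  const int in_rank = static_cast<int>(in_dims.size());
  const int out_rank = static_cast<int>(out_dims.size());
  if (axis == -1) axis = out_rank - in_rank;  // default: align trailing dims
  assert(axis >= 0 && axis + in_rank <= out_rank);

  std::vector<int> reduce_dims;
  for (int i = 0; i < out_rank; ++i) {
    const bool covered = i >= axis && i < axis + in_rank;
    if (!covered || in_dims[i - axis] == 1) {
      reduce_dims.push_back(i);  // broadcast axis: gradient must be summed
    }
  }
  return reduce_dims;
}

// e.g. in_dims = {3, 1}, out_dims = {2, 3, 4}, axis = 1  ->  {0, 2}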