未验证 提交 70b9f2ac 编写于 作者: T taixiurong 提交者: GitHub

dropout support Seed, fix elementwise_add_grad bug, test=kunlun (#39656)

上级 8363406a
...@@ -32,20 +32,18 @@ class DropoutXPUKernel : public framework::OpKernel<T> { ...@@ -32,20 +32,18 @@ class DropoutXPUKernel : public framework::OpKernel<T> {
context.Attr<std::string>("dropout_implementation"); context.Attr<std::string>("dropout_implementation");
auto& dev_ctx = context.template device_context<DeviceContext>(); auto& dev_ctx = context.template device_context<DeviceContext>();
PADDLE_ENFORCE_EQ(!context.HasInput("Seed"), true, auto* seed =
platform::errors::InvalidArgument( context.HasInput("Seed") ? context.Input<Tensor>("Seed") : nullptr;
("Input(Seed) not supported on XPU")));
int is_upscale = (dropout_implementation == "upscale_in_train"); int is_upscale = (dropout_implementation == "upscale_in_train");
if (!context.Attr<bool>("is_test")) { if (!context.Attr<bool>("is_test")) {
std::random_device rnd; int seed_data = 0;
// int seed = (context.Attr<bool>("fix_seed")) ? if (seed) {
// int(context.Attr<int>("seed")) : (rnd()); seed_data = *(seed->data<int>());
int seed = 0;
if (context.Attr<bool>("fix_seed") == true) {
seed = static_cast<int>(context.Attr<int>("seed"));
} else { } else {
seed = rnd(); seed_data =
context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : 0;
} }
auto* mask = context.Output<Tensor>("Mask"); auto* mask = context.Output<Tensor>("Mask");
...@@ -55,26 +53,26 @@ class DropoutXPUKernel : public framework::OpKernel<T> { ...@@ -55,26 +53,26 @@ class DropoutXPUKernel : public framework::OpKernel<T> {
int r = xpu::constant(dev_ctx.x_context(), int r = xpu::constant(dev_ctx.x_context(),
reinterpret_cast<XPUTyp*>(y_data), y->numel(), reinterpret_cast<XPUTyp*>(y_data), y->numel(),
XPUTyp(0)); XPUTyp(0));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant "); PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
r = xpu::constant(dev_ctx.x_context(), r = xpu::constant(dev_ctx.x_context(),
reinterpret_cast<XPUTyp*>(mask_data), mask->numel(), reinterpret_cast<XPUTyp*>(mask_data), mask->numel(),
XPUTyp(0)); XPUTyp(0));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant "); PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
return; return;
} }
int r = xpu::dropout(dev_ctx.x_context(), int r = xpu::dropout(dev_ctx.x_context(),
reinterpret_cast<const XPUTyp*>(x->data<T>()), reinterpret_cast<const XPUTyp*>(x->data<T>()),
reinterpret_cast<XPUTyp*>(y->data<T>()), reinterpret_cast<XPUTyp*>(y->data<T>()),
reinterpret_cast<XPUTyp*>(mask_data), seed, reinterpret_cast<XPUTyp*>(mask_data), seed_data,
mask->numel(), is_upscale, dropout_prob); mask->numel(), is_upscale, dropout_prob);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout "); PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout");
} else { } else {
float scale = float scale =
(is_upscale) ? (1.0) : (static_cast<float>(1.0f - dropout_prob)); (is_upscale) ? (1.0) : (static_cast<float>(1.0f - dropout_prob));
int r = xpu::scale( int r = xpu::scale(
dev_ctx.x_context(), reinterpret_cast<const XPUTyp*>(x_data), dev_ctx.x_context(), reinterpret_cast<const XPUTyp*>(x_data),
reinterpret_cast<XPUTyp*>(y_data), x->numel(), false, scale, 0.0f); reinterpret_cast<XPUTyp*>(y_data), x->numel(), false, scale, 0.0f);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale "); PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
} }
} }
}; };
...@@ -103,7 +101,7 @@ class DropoutGradXPUKernel : public framework::OpKernel<T> { ...@@ -103,7 +101,7 @@ class DropoutGradXPUKernel : public framework::OpKernel<T> {
reinterpret_cast<const XPUType*>(mask_data), reinterpret_cast<const XPUType*>(mask_data),
reinterpret_cast<XPUType*>(grad_x->data<T>()), reinterpret_cast<XPUType*>(grad_x->data<T>()),
grad_y->numel()); grad_y->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul "); PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul");
return; return;
} }
...@@ -117,13 +115,13 @@ class DropoutGradXPUKernel : public framework::OpKernel<T> { ...@@ -117,13 +115,13 @@ class DropoutGradXPUKernel : public framework::OpKernel<T> {
reinterpret_cast<const XPUType*>(mask->data<T>()), reinterpret_cast<const XPUType*>(mask->data<T>()),
reinterpret_cast<XPUType*>(mask_new), mask->numel(), reinterpret_cast<XPUType*>(mask_new), mask->numel(),
false, scale, 0.0f); false, scale, 0.0f);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale "); PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
r = xpu::mul(dev_ctx.x_context(), r = xpu::mul(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(grad_y->data<T>()), reinterpret_cast<const XPUType*>(grad_y->data<T>()),
reinterpret_cast<const XPUType*>(mask_new), reinterpret_cast<const XPUType*>(mask_new),
reinterpret_cast<XPUType*>(grad_x->data<T>()), reinterpret_cast<XPUType*>(grad_x->data<T>()),
grad_y->numel()); grad_y->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul "); PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul");
} else { } else {
int r = int r =
xpu::dropout_grad(dev_ctx.x_context(), xpu::dropout_grad(dev_ctx.x_context(),
...@@ -131,7 +129,7 @@ class DropoutGradXPUKernel : public framework::OpKernel<T> { ...@@ -131,7 +129,7 @@ class DropoutGradXPUKernel : public framework::OpKernel<T> {
reinterpret_cast<const XPUType*>(grad_y->data<T>()), reinterpret_cast<const XPUType*>(grad_y->data<T>()),
reinterpret_cast<XPUType*>(grad_x->data<T>()), reinterpret_cast<XPUType*>(grad_x->data<T>()),
dropout_prob, grad_y->numel()); dropout_prob, grad_y->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout_grad "); PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout_grad");
} }
} }
}; };
......
...@@ -34,17 +34,6 @@ class ElementwiseAddXPUKernel : public framework::OpKernel<T> { ...@@ -34,17 +34,6 @@ class ElementwiseAddXPUKernel : public framework::OpKernel<T> {
} }
}; };
// Returns, in ascending order, every axis index at which `xdims` and `ydims`
// hold different extents.
//
// Every index of `xdims` is visited and `ydims` is read at that same index
// without a bounds check, so `ydims` must be at least as long as `xdims`.
static std::vector<int> get_rdims(const std::vector<int>& xdims,
                                  const std::vector<int>& ydims) {
  std::vector<int> mismatched_axes;
  size_t axis = 0;
  for (const int x_extent : xdims) {
    const bool same_extent = (x_extent == ydims[axis]);
    if (!same_extent) {
      // Record the position of the differing axis, not its extent.
      mismatched_axes.push_back(static_cast<int>(axis));
    }
    ++axis;
  }
  return mismatched_axes;
}
template <typename T> template <typename T>
class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> { class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type; using XPUType = typename XPUTypeTrait<T>::Type;
...@@ -53,64 +42,19 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> { ...@@ -53,64 +42,19 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx); ElemwiseGradKernel<T>::Compute(ctx);
auto* x = ctx.Input<framework::Tensor>("X"); auto* x = ctx.Input<framework::Tensor>("X");
auto* y = ctx.Input<framework::Tensor>("Y");
auto* dz = ctx.Input<framework::Tensor>(framework::GradVarName("Out")); auto* dz = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X")); auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y")); auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
const framework::DDim& x_dims = x->dims();
const framework::DDim& y_dims = y->dims();
const framework::DDim& dz_dims = dz->dims(); const framework::DDim& dz_dims = dz->dims();
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
int max_dim = std::max(x_dims.size(), y_dims.size());
PADDLE_ENFORCE_GE(
axis, 0,
platform::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LT(
axis, max_dim,
platform::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.", max_dim,
axis));
std::vector<int> x_dims_vec(max_dim, 1);
std::vector<int> y_dims_vec(max_dim, 1);
std::vector<int> z_dims_vec(max_dim, 1);
if (x_dims.size() == max_dim) {
for (int i = 0; i < max_dim; i++) {
x_dims_vec[i] = x_dims[i];
}
} else {
for (int i = 0; i < x_dims.size(); i++) {
x_dims_vec[i + axis] = x_dims[i];
}
}
if (y_dims.size() == max_dim) {
for (int i = 0; i < max_dim; i++) {
y_dims_vec[i] = y_dims[i];
}
} else {
for (int i = 0; i < y_dims.size(); i++) {
y_dims_vec[i + axis] = y_dims[i];
}
}
for (int i = 0; i < max_dim; i++) {
z_dims_vec[i] = dz_dims[i];
}
std::vector<int> rdims_for_x;
std::vector<int> rdims_for_y;
rdims_for_x = get_rdims(x_dims_vec, z_dims_vec);
rdims_for_y = get_rdims(y_dims_vec, z_dims_vec);
const T* dz_data = dz->data<T>(); const T* dz_data = dz->data<T>();
auto& dev_ctx = auto& dev_ctx =
ctx.template device_context<paddle::platform::XPUDeviceContext>(); ctx.template device_context<paddle::platform::XPUDeviceContext>();
if (dx != nullptr) { if (dx != nullptr) {
T* dx_data = dx->mutable_data<T>(ctx.GetPlace()); T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
if (rdims_for_x.size() == 0) { if (dx->dims() == dz_dims) {
if (dx_data != dz_data) { if (dx_data != dz_data) {
framework::TensorCopy( framework::TensorCopy(
*dz, ctx.GetPlace(), *dz, ctx.GetPlace(),
...@@ -123,27 +67,31 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> { ...@@ -123,27 +67,31 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
dx->clear(); dx->clear();
dx->mutable_data<T>(x->dims(), ctx.GetPlace()); dx->mutable_data<T>(x->dims(), ctx.GetPlace());
} }
std::vector<int> reduce_dims = GetReduceDim(dx->dims(), dz_dims, axis);
std::vector<int> dz_vector = framework::vectorize<int>(dz_dims);
int ret = xpu::reduce_sum<XPUType>( int ret = xpu::reduce_sum<XPUType>(
dev_ctx.x_context(), reinterpret_cast<const XPUType*>(dz_data), dev_ctx.x_context(), reinterpret_cast<const XPUType*>(dz_data),
reinterpret_cast<XPUType*>(dx_data), z_dims_vec, rdims_for_x); reinterpret_cast<XPUType*>(dx_data), dz_vector, reduce_dims);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_sum "); PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_sum");
} }
} }
if (dy != nullptr) { if (dy != nullptr) {
T* dy_data = dy->mutable_data<T>(ctx.GetPlace()); T* dy_data = dy->mutable_data<T>(ctx.GetPlace());
if (rdims_for_y.size() == 0) { if (dy->dims() == dz_dims) {
if (dy_data != dz_data) { if (dy_data != dz_data) {
framework::TensorCopy( framework::TensorCopy(
*dz, ctx.GetPlace(), *dz, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), dy); ctx.template device_context<platform::DeviceContext>(), dy);
} }
} else { } else {
std::vector<int> reduce_dims = GetReduceDim(dy->dims(), dz_dims, axis);
std::vector<int> dz_vector = framework::vectorize<int>(dz_dims);
int ret = xpu::reduce_sum<XPUType>( int ret = xpu::reduce_sum<XPUType>(
dev_ctx.x_context(), reinterpret_cast<const XPUType*>(dz_data), dev_ctx.x_context(), reinterpret_cast<const XPUType*>(dz_data),
reinterpret_cast<XPUType*>(dy_data), z_dims_vec, rdims_for_y); reinterpret_cast<XPUType*>(dy_data), dz_vector, reduce_dims);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_sum "); PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_sum");
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册