未验证 提交 b207b8a7 编写于 作者: W wangchaochaohu 提交者: GitHub

[cherry-pick]memory optimization for fuse pattern of elemwise_add + act (#30303)

* reduce the  occupied size  of memory for the fused pattern of elementwise_add Op and activation Op(relu Op for example) (#29885)

* register OPMaker and Infer Shape Check for fused_elementwise_add (#30259)
上级 2db79f0a
......@@ -183,7 +183,7 @@ ir::Graph *FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
std::string d_ele_y_n = d_ele_y->Name();
OpDesc desc;
desc.SetType("fused_elemwise_activation_grad");
desc.SetType("fused_elemwise_add_activation_grad");
desc.SetInput("IntermediateOut", {});
desc.SetInput("X", {});
desc.SetInput("Y", std::vector<std::string>({ele_y_n}));
......@@ -231,7 +231,7 @@ Node *FuseElewiseAddActPass::CreateFuseElewiseAddActNode(
desc.SetInput("Y", std::vector<std::string>({ele_y_n}));
desc.SetOutput("Out", std::vector<std::string>({act_out_n}));
desc.SetOutput("IntermediateOut", std::vector<std::string>({ele_out_n}));
desc.SetType("fused_elemwise_activation");
desc.SetType("fused_elemwise_add_activation");
desc.SetAttr("save_intermediate_out", true);
desc.SetAttr("functor_list", std::vector<std::string>(
{op_1->Op()->Type(), op_2->Op()->Type()}));
......@@ -251,7 +251,7 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const {
std::unordered_set<const Node *> need_removed_nodes;
for (auto &cur_node : graph->Nodes()) {
if (cur_node->IsVar()) continue;
if (cur_node->Name() == "fused_elemwise_activation") {
if (cur_node->Name() == "fused_elemwise_add_activation") {
bool save_intermediate_out = BOOST_GET_CONST(
bool, cur_node->Op()->GetAttr("save_intermediate_out"));
auto intermediate_out_args = cur_node->Op()->Output("IntermediateOut");
......@@ -272,7 +272,7 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const {
}
}
}
} else if (cur_node->Name() == "fused_elemwise_activation_grad") {
} else if (cur_node->Name() == "fused_elemwise_add_activation_grad") {
auto intermediate_out_grad_args =
cur_node->Op()->Output(GradVarName("IntermediateOut"));
PADDLE_ENFORCE_EQ(
......
......@@ -2273,8 +2273,9 @@ template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
bool UseIntermediateOut>
struct FusedElemwiseAndActGradNoBroadcast {
HOSTDEVICE void operator()(size_t i) {
T x_val = x_[i];
T y_val = y_[i];
T zero = static_cast<T>(0);
T x_val = (x_ == nullptr) ? zero : x_[i];
T y_val = (y_ == nullptr) ? zero : y_[i];
T out_val = out_[i];
T dout_val = dout_[i];
T intermediate_out_val = UseIntermediateOut
......@@ -2320,11 +2321,14 @@ void FusedElemwiseAndActGradComputeNoBroadcast(
size_t N = static_cast<size_t>(framework::product(x_dim));
platform::ForRange<DeviceContext> for_range(
ctx.template device_context<DeviceContext>(), N);
for_range(
FusedElemwiseAndActGradNoBroadcast<T, DX_OP, DY_OP, DIntermediate_OP,
UseIntermediateOut>{
x->data<T>(), y->data<T>(),
intermediate_out ? intermediate_out->data<T>() : nullptr,
const T *x_data = nullptr;
const T *y_data = nullptr;
if (x->IsInitialized()) x_data = x->data<T>();
if (y->IsInitialized()) y_data = y->data<T>();
for_range(FusedElemwiseAndActGradNoBroadcast<
T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut>{
x_data, y_data, intermediate_out ? intermediate_out->data<T>() : nullptr,
out->data<T>(), dout->data<T>(), dx_op, dy_op, dintermediate_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()),
......@@ -2340,6 +2344,7 @@ static void FusedElemwiseAndActGradBroadcast1CPU(
const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op,
DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
int64_t tmp_out_idx, x_idx, y_idx;
T zero = static_cast<T>(0);
for (int i = 0; i < h; ++i) {
for (int j = 0; j < w; ++j) {
int offset = i * w + j;
......@@ -2347,6 +2352,8 @@ static void FusedElemwiseAndActGradBroadcast1CPU(
tmp_out_idx = BcastY ? j : offset;
y_idx = BcastY ? j : offset;
x_idx = BcastY ? offset : j;
T x_val = (x == nullptr) ? zero : x[x_idx];
T y_val = (y == nullptr) ? zero : y[y_idx];
if (SameShapeOfIntermediateOutAndOut) {
tmp_out_idx = offset;
......@@ -2354,11 +2361,10 @@ static void FusedElemwiseAndActGradBroadcast1CPU(
if (dx != nullptr) {
T tmp = UseIntermediateOut
? dx_op.UseIntermediateOut(x[x_idx], y[y_idx],
? dx_op.UseIntermediateOut(x_val, y_val,
intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dx_op.Recompute(x[x_idx], y[y_idx], out[offset],
dout[offset]);
: dx_op.Recompute(x_val, y_val, out[offset], dout[offset]);
if (BcastY) {
dx[x_idx] = tmp;
......@@ -2372,11 +2378,10 @@ static void FusedElemwiseAndActGradBroadcast1CPU(
}
if (dy != nullptr) {
T tmp = UseIntermediateOut
? dy_op.UseIntermediateOut(x[x_idx], y[y_idx],
? dy_op.UseIntermediateOut(x_val, y_val,
intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dy_op.Recompute(x[x_idx], y[y_idx], out[offset],
dout[offset]);
: dy_op.Recompute(x_val, y_val, out[offset], dout[offset]);
if (BcastY) {
if (i == 0) {
dy[y_idx] = tmp;
......@@ -2390,10 +2395,10 @@ static void FusedElemwiseAndActGradBroadcast1CPU(
if (d_intermediate != nullptr) {
T tmp = UseIntermediateOut
? dintermediate_op.UseIntermediateOut(
x[x_idx], intermediate_out[tmp_out_idx], out[offset],
x_val, intermediate_out[tmp_out_idx], out[offset],
dout[offset])
: dintermediate_op.Recompute(x[x_idx], y[y_idx],
out[offset], dout[i]);
: dintermediate_op.Recompute(x_val, y_val, out[offset],
dout[i]);
if (SameShapeOfIntermediateOutAndOut) {
d_intermediate[tmp_out_idx] = tmp;
} else {
......@@ -2416,6 +2421,7 @@ static void FusedElemwiseAndActGradBroadcast2CPU(
const T *dout, int pre, int n, int post, DX_OP dx_op, DY_OP dy_op,
DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
int64_t tmp_out_idx, x_idx, y_idx;
T zero = static_cast<T>(0);
for (int i = 0; i < pre; ++i) {
for (int j = 0; j < n; ++j) {
for (int k = 0; k < post; ++k) {
......@@ -2425,17 +2431,20 @@ static void FusedElemwiseAndActGradBroadcast2CPU(
y_idx = BcastY ? j : offset;
x_idx = BcastY ? offset : j;
T x_val = (x == nullptr) ? zero : x[x_idx];
T y_val = (y == nullptr) ? zero : y[y_idx];
if (SameShapeOfIntermediateOutAndOut) {
tmp_out_idx = offset;
}
if (dx != nullptr) {
T tmp = UseIntermediateOut
? dx_op.UseIntermediateOut(x[x_idx], y[y_idx],
T tmp =
UseIntermediateOut
? dx_op.UseIntermediateOut(x_val, y_val,
intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dx_op.Recompute(x[x_idx], y[y_idx], out[offset],
dout[offset]);
: dx_op.Recompute(x_val, y_val, out[offset], dout[offset]);
if (BcastY) {
dx[x_idx] = tmp;
......@@ -2448,12 +2457,12 @@ static void FusedElemwiseAndActGradBroadcast2CPU(
}
}
if (dy != nullptr) {
T tmp = UseIntermediateOut
? dy_op.UseIntermediateOut(x[x_idx], y[y_idx],
T tmp =
UseIntermediateOut
? dy_op.UseIntermediateOut(x_val, y_val,
intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dy_op.Recompute(x[x_idx], y[y_idx], out[offset],
dout[offset]);
: dy_op.Recompute(x_val, y_val, out[offset], dout[offset]);
if (BcastY) {
if (i == 0 && k == 0) {
dy[y_idx] = tmp;
......@@ -2467,10 +2476,10 @@ static void FusedElemwiseAndActGradBroadcast2CPU(
if (d_intermediate != nullptr) {
T tmp = UseIntermediateOut
? dintermediate_op.UseIntermediateOut(
x[x_idx], intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dintermediate_op.Recompute(x[x_idx], y[y_idx],
out[offset], dout[i]);
x_val, intermediate_out[tmp_out_idx], out[offset],
dout[offset])
: dintermediate_op.Recompute(x_val, y_val, out[offset],
dout[i]);
if (SameShapeOfIntermediateOutAndOut) {
d_intermediate[tmp_out_idx] = tmp;
} else {
......@@ -2499,6 +2508,7 @@ static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel(
int tid = threadIdx.x;
T val(0), inter_val(0);
int64_t tmp_out_idx, x_idx, y_idx;
T zero = static_cast<T>(0);
do {
int offset = i * w + j;
......@@ -2506,18 +2516,19 @@ static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel(
tmp_out_idx = BcastY ? j : offset;
y_idx = BcastY ? j : offset;
x_idx = BcastY ? offset : j;
T x_val = (x == nullptr) ? zero : x[x_idx];
T y_val = (y == nullptr) ? zero : y[y_idx];
if (SameShapeOfIntermediateOutAndOut) {
tmp_out_idx = offset;
}
if (dx != nullptr) {
T tmp =
UseIntermediateOut
? dx_op.UseIntermediateOut(x[x_idx], y[y_idx],
T tmp = UseIntermediateOut
? dx_op.UseIntermediateOut(x_val, y_val,
intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dx_op.Recompute(x[x_idx], y[y_idx], out[offset], dout[offset]);
: dx_op.Recompute(x_val, y_val, out[offset], dout[offset]);
if (BcastY) {
dx[x_idx] = tmp;
......@@ -2526,12 +2537,11 @@ static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel(
}
}
if (dy != nullptr) {
T tmp =
UseIntermediateOut
? dy_op.UseIntermediateOut(x[x_idx], y[y_idx],
T tmp = UseIntermediateOut
? dy_op.UseIntermediateOut(x_val, y_val,
intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dy_op.Recompute(x[x_idx], y[y_idx], out[offset], dout[offset]);
: dy_op.Recompute(x_val, y_val, out[offset], dout[offset]);
if (BcastY) {
val += tmp;
} else {
......@@ -2543,7 +2553,7 @@ static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel(
? dintermediate_op.UseIntermediateOut(
y[y_idx], intermediate_out[tmp_out_idx], out[offset],
dout[offset])
: dintermediate_op.Recompute(x[x_idx], y[y_idx], out[offset],
: dintermediate_op.Recompute(x_val, y_val, out[offset],
dout[offset]);
if (SameShapeOfIntermediateOutAndOut) {
d_intermediate[tmp_out_idx] = tmp;
......@@ -2610,6 +2620,7 @@ static __global__ void FusedElemwiseAndActGradBroadcast2CUDAKernel(
T val(0), inter_val(0);
int ttid = tid;
int64_t tmp_out_idx, x_idx, y_idx;
T zero = static_cast<T>(0);
while (true) {
int i = ttid / post;
int k = ttid % post;
......@@ -2620,18 +2631,19 @@ static __global__ void FusedElemwiseAndActGradBroadcast2CUDAKernel(
tmp_out_idx = BcastY ? j : offset;
y_idx = BcastY ? j : offset;
x_idx = BcastY ? offset : j;
T x_val = (x == nullptr) ? zero : x[x_idx];
T y_val = (y == nullptr) ? zero : y[y_idx];
if (SameShapeOfIntermediateOutAndOut) {
tmp_out_idx = offset;
}
if (dx != nullptr) {
T tmp =
UseIntermediateOut
? dx_op.UseIntermediateOut(x[x_idx], y[y_idx],
T tmp = UseIntermediateOut
? dx_op.UseIntermediateOut(x_val, y_val,
intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dx_op.Recompute(x[x_idx], y[y_idx], out[offset], dout[offset]);
: dx_op.Recompute(x_val, y_val, out[offset], dout[offset]);
if (BcastY) {
dx[x_idx] = tmp;
......@@ -2640,12 +2652,11 @@ static __global__ void FusedElemwiseAndActGradBroadcast2CUDAKernel(
}
}
if (dy != nullptr) {
T tmp =
UseIntermediateOut
? dy_op.UseIntermediateOut(x[x_idx], y[y_idx],
T tmp = UseIntermediateOut
? dy_op.UseIntermediateOut(x_val, y_val,
intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dy_op.Recompute(x[x_idx], y[y_idx], out[offset], dout[offset]);
: dy_op.Recompute(x_val, y_val, out[offset], dout[offset]);
if (BcastY) {
val += tmp;
} else {
......@@ -2655,9 +2666,9 @@ static __global__ void FusedElemwiseAndActGradBroadcast2CUDAKernel(
if (d_intermediate != nullptr) {
T tmp = UseIntermediateOut
? dintermediate_op.UseIntermediateOut(
y[y_idx], intermediate_out[tmp_out_idx], out[offset],
y_val, intermediate_out[tmp_out_idx], out[offset],
dout[offset])
: dintermediate_op.Recompute(x[x_idx], y[y_idx], out[offset],
: dintermediate_op.Recompute(x_val, y_val, out[offset],
dout[offset]);
if (SameShapeOfIntermediateOutAndOut) {
d_intermediate[tmp_out_idx] = tmp;
......@@ -2730,16 +2741,20 @@ void FusedElemwiseAndActGradComputeWithBroadcast(
int pre, n, post, is_run_common_broadcast;
get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post, &is_run_common_broadcast);
const T *x_data = nullptr;
const T *y_data = nullptr;
if (x->IsInitialized()) x_data = x->data<T>();
if (y->IsInitialized()) y_data = y->data<T>();
if (post == 1) {
int h = pre;
int w = n;
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
FusedElemwiseAndActGradBroadcast1CUDA<T, DX_OP, DY_OP, DIntermediate_OP,
UseIntermediateOut, BcastY,
SameShapeOfIntermediateOutAndOut>(
ctx.template device_context<DeviceContext>().stream(), x->data<T>(),
y->data<T>(),
ctx.template device_context<DeviceContext>().stream(), x_data, y_data,
intermediate_out == nullptr ? nullptr : intermediate_out->data<T>(),
out->data<T>(), dout->data<T>(), h, w, dx_op, dy_op, dintermediate_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
......@@ -2751,7 +2766,7 @@ void FusedElemwiseAndActGradComputeWithBroadcast(
FusedElemwiseAndActGradBroadcast1CPU<T, DX_OP, DY_OP, DIntermediate_OP,
UseIntermediateOut, BcastY,
SameShapeOfIntermediateOutAndOut>(
x->data<T>(), y->data<T>(),
x_data, y_data,
intermediate_out == nullptr ? nullptr : intermediate_out->data<T>(),
out->data<T>(), dout->data<T>(), h, w, dx_op, dy_op, dintermediate_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
......@@ -2765,8 +2780,7 @@ void FusedElemwiseAndActGradComputeWithBroadcast(
FusedElemwiseAndActGradBroadcast2CUDA<T, DX_OP, DY_OP, DIntermediate_OP,
UseIntermediateOut, BcastY,
SameShapeOfIntermediateOutAndOut>(
ctx.template device_context<DeviceContext>().stream(), x->data<T>(),
y->data<T>(),
ctx.template device_context<DeviceContext>().stream(), x_data, y_data,
intermediate_out == nullptr ? nullptr : intermediate_out->data<T>(),
out->data<T>(), dout->data<T>(), pre, n, post, dx_op, dy_op,
dintermediate_op,
......@@ -2779,7 +2793,7 @@ void FusedElemwiseAndActGradComputeWithBroadcast(
FusedElemwiseAndActGradBroadcast2CPU<T, DX_OP, DY_OP, DIntermediate_OP,
UseIntermediateOut, BcastY,
SameShapeOfIntermediateOutAndOut>(
x->data<T>(), y->data<T>(),
x_data, y_data,
intermediate_out == nullptr ? nullptr : intermediate_out->data<T>(),
out->data<T>(), dout->data<T>(), pre, n, post, dx_op, dy_op,
dintermediate_op,
......
......@@ -287,6 +287,15 @@ class FusedElemwiseActivationGradMaker
}
};
class FusedElemwiseAddActivationMaker : public FusedElemwiseActivationMaker {};
template <typename T>
class FusedElemwiseAddActivationGradMaker
: public FusedElemwiseActivationGradMaker<T> {
public:
using FusedElemwiseActivationGradMaker<T>::FusedElemwiseActivationGradMaker;
};
class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -361,10 +370,61 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Y"), ctx.GetPlace());
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
};
class FusedElemwiseAddActivationOp : public FusedElemwiseActivationOp {
public:
using FusedElemwiseActivationOp::FusedElemwiseActivationOp;
void InferShape(framework::InferShapeContext *ctx) const override {
FusedElemwiseActivationOp::InferShape(ctx);
std::vector<std::string> functor_names =
ctx->Attrs().Get<std::vector<std::string>>("functor_list");
bool elemntwise_add_detected = false;
for (auto names : functor_names) {
if (names == "elementwise_add") {
elemntwise_add_detected = true;
break;
}
}
PADDLE_ENFORCE_EQ(
elemntwise_add_detected, true,
platform::errors::InvalidArgument(
"When the FusedElemwiseAddActivationOp Is used in fused pass, the "
"elementwise_add Op must be"
"detected and used, Please check the fuse pass pattern"));
}
};
class FusedElemwiseAddActivationOpGrad : public FusedElemwiseActivationOpGrad {
public:
using FusedElemwiseActivationOpGrad::FusedElemwiseActivationOpGrad;
void InferShape(framework::InferShapeContext *ctx) const override {
FusedElemwiseActivationOpGrad::InferShape(ctx);
std::vector<std::string> functor_names =
ctx->Attrs().Get<std::vector<std::string>>("functor_list");
bool elemntwise_add_grad_detected = false;
for (auto names : functor_names) {
if (names == "elementwise_add_grad") {
elemntwise_add_grad_detected = true;
break;
}
}
PADDLE_ENFORCE_EQ(
elemntwise_add_grad_detected, true,
platform::errors::InvalidArgument(
"When the FusedElemwiseAddActivationOpGrad Is used in fused pass, "
"the elementwise_add_grad Op must be"
"detected and used, Please check the fuse pass pattern"));
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(
FusedElemwiseAddActivationNoNeddBufVarInferer, "X", "Y");
} // namespace operators
} // namespace paddle
......@@ -390,3 +450,27 @@ REGISTER_OP_CPU_KERNEL(
float>,
ops::FusedElemwiseActivationGradKernel<paddle::platform::CPUDeviceContext,
double>);
// for memory optimization, we register the fused_elemwise_add_activation OP
REGISTER_OPERATOR(
fused_elemwise_add_activation, ops::FusedElemwiseAddActivationOp,
ops::FusedElemwiseAddActivationMaker,
ops::FusedElemwiseAddActivationGradMaker<paddle::framework::OpDesc>,
ops::FusedElemwiseAddActivationGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(fused_elemwise_add_activation_grad,
ops::FusedElemwiseAddActivationNoNeddBufVarInferer,
ops::FusedElemwiseAddActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
fused_elemwise_add_activation,
ops::FusedElemwiseActivationKernel<paddle::platform::CPUDeviceContext,
float>,
ops::FusedElemwiseActivationKernel<paddle::platform::CPUDeviceContext,
double>);
REGISTER_OP_CPU_KERNEL(
fused_elemwise_add_activation_grad,
ops::FusedElemwiseActivationGradKernel<paddle::platform::CPUDeviceContext,
float>,
ops::FusedElemwiseActivationGradKernel<paddle::platform::CPUDeviceContext,
double>);
......@@ -32,3 +32,21 @@ REGISTER_OP_CUDA_KERNEL(
double>,
ops::FusedElemwiseActivationGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(
fused_elemwise_add_activation,
ops::FusedElemwiseActivationKernel<paddle::platform::CUDADeviceContext,
float>,
ops::FusedElemwiseActivationKernel<paddle::platform::CUDADeviceContext,
double>,
ops::FusedElemwiseActivationKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(
fused_elemwise_add_activation_grad,
ops::FusedElemwiseActivationGradKernel<paddle::platform::CUDADeviceContext,
float>,
ops::FusedElemwiseActivationGradKernel<paddle::platform::CUDADeviceContext,
double>,
ops::FusedElemwiseActivationGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
......@@ -77,4 +77,6 @@ class TestMNIST(TestParallelExecutorBase):
if __name__ == '__main__':
import paddle
paddle.enable_static()
unittest.main()
......@@ -390,4 +390,6 @@ for mode in {0, 1}:
grad_chek=False)
if __name__ == '__main__':
import paddle
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册