Unverified · Commit b207b8a7 · Authored by wangchaochaohu, committed by GitHub

[cherry-pick] Memory optimization for the fused pattern of elemwise_add + act (#30303)

* Reduce the memory occupied by the fused pattern of the elementwise_add op and an activation op (relu, for example) (#29885)

* Register the OpMaker and the InferShape check for fused_elementwise_add (#30259)
Parent 2db79f0a
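For context, this pass is switched on from user code through `BuildStrategy.fuse_elewise_add_act_ops`. Below is a minimal sketch of the static-graph flow that exercises the fused pattern; the tiny fc + relu network and all variable names are illustrative assumptions, not part of this patch.

```python
# Hedged sketch: enabling the elementwise_add + activation fuse pass.
# Assumes the Paddle 2.0-era static-graph API; the network is illustrative.
import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_static()

main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.data(name="x", shape=[-1, 64], dtype="float32")
    hidden = fluid.layers.fc(input=x, size=64)  # bias add -> elementwise_add
    out = fluid.layers.relu(hidden)             # activation to be fused
    loss = fluid.layers.mean(out)
    fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

build_strategy = fluid.BuildStrategy()
build_strategy.fuse_elewise_add_act_ops = True  # runs FuseElewiseAddActPass

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_prog)
compiled = fluid.CompiledProgram(main_prog).with_data_parallel(
    loss_name=loss.name, build_strategy=build_strategy)
exe.run(compiled,
        feed={"x": np.random.rand(8, 64).astype("float32")},
        fetch_list=[loss.name])
```

With this patch applied, the pass emits the dedicated `fused_elemwise_add_activation` / `fused_elemwise_add_activation_grad` ops instead of the generic `fused_elemwise_activation` pair, which is what enables the buffer freeing in the changes below.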
@@ -183,7 +183,7 @@ ir::Graph *FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
   std::string d_ele_y_n = d_ele_y->Name();
   OpDesc desc;
-  desc.SetType("fused_elemwise_activation_grad");
+  desc.SetType("fused_elemwise_add_activation_grad");
   desc.SetInput("IntermediateOut", {});
   desc.SetInput("X", {});
   desc.SetInput("Y", std::vector<std::string>({ele_y_n}));
@@ -231,7 +231,7 @@ Node *FuseElewiseAddActPass::CreateFuseElewiseAddActNode(
   desc.SetInput("Y", std::vector<std::string>({ele_y_n}));
   desc.SetOutput("Out", std::vector<std::string>({act_out_n}));
   desc.SetOutput("IntermediateOut", std::vector<std::string>({ele_out_n}));
-  desc.SetType("fused_elemwise_activation");
+  desc.SetType("fused_elemwise_add_activation");
   desc.SetAttr("save_intermediate_out", true);
   desc.SetAttr("functor_list", std::vector<std::string>(
                                    {op_1->Op()->Type(), op_2->Op()->Type()}));
@@ -251,7 +251,7 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const {
   std::unordered_set<const Node *> need_removed_nodes;
   for (auto &cur_node : graph->Nodes()) {
     if (cur_node->IsVar()) continue;
-    if (cur_node->Name() == "fused_elemwise_activation") {
+    if (cur_node->Name() == "fused_elemwise_add_activation") {
       bool save_intermediate_out = BOOST_GET_CONST(
           bool, cur_node->Op()->GetAttr("save_intermediate_out"));
       auto intermediate_out_args = cur_node->Op()->Output("IntermediateOut");
@@ -272,7 +272,7 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const {
         }
       }
     }
-    } else if (cur_node->Name() == "fused_elemwise_activation_grad") {
+    } else if (cur_node->Name() == "fused_elemwise_add_activation_grad") {
       auto intermediate_out_grad_args =
           cur_node->Op()->Output(GradVarName("IntermediateOut"));
       PADDLE_ENFORCE_EQ(
......
@@ -2273,8 +2273,9 @@ template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
           bool UseIntermediateOut>
 struct FusedElemwiseAndActGradNoBroadcast {
   HOSTDEVICE void operator()(size_t i) {
-    T x_val = x_[i];
-    T y_val = y_[i];
+    T zero = static_cast<T>(0);
+    T x_val = (x_ == nullptr) ? zero : x_[i];
+    T y_val = (y_ == nullptr) ? zero : y_[i];
     T out_val = out_[i];
     T dout_val = dout_[i];
     T intermediate_out_val = UseIntermediateOut
@@ -2320,16 +2321,19 @@ void FusedElemwiseAndActGradComputeNoBroadcast(
   size_t N = static_cast<size_t>(framework::product(x_dim));
   platform::ForRange<DeviceContext> for_range(
       ctx.template device_context<DeviceContext>(), N);
-  for_range(
-      FusedElemwiseAndActGradNoBroadcast<T, DX_OP, DY_OP, DIntermediate_OP,
-                                         UseIntermediateOut>{
-          x->data<T>(), y->data<T>(),
-          intermediate_out ? intermediate_out->data<T>() : nullptr,
-          out->data<T>(), dout->data<T>(), dx_op, dy_op, dintermediate_op,
-          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
-          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()),
-          dintermediate == nullptr ? nullptr : dintermediate->mutable_data<T>(
-                                                   ctx.GetPlace())});
+  const T *x_data = nullptr;
+  const T *y_data = nullptr;
+  if (x->IsInitialized()) x_data = x->data<T>();
+  if (y->IsInitialized()) y_data = y->data<T>();
+
+  for_range(FusedElemwiseAndActGradNoBroadcast<
+            T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut>{
+      x_data, y_data, intermediate_out ? intermediate_out->data<T>() : nullptr,
+      out->data<T>(), dout->data<T>(), dx_op, dy_op, dintermediate_op,
+      dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
+      dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()),
+      dintermediate == nullptr ? nullptr : dintermediate->mutable_data<T>(
+                                               ctx.GetPlace())});
 }
 
 template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
@@ -2340,6 +2344,7 @@ static void FusedElemwiseAndActGradBroadcast1CPU(
     const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op,
     DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
   int64_t tmp_out_idx, x_idx, y_idx;
+  T zero = static_cast<T>(0);
   for (int i = 0; i < h; ++i) {
     for (int j = 0; j < w; ++j) {
       int offset = i * w + j;
@@ -2347,6 +2352,8 @@ static void FusedElemwiseAndActGradBroadcast1CPU(
       tmp_out_idx = BcastY ? j : offset;
       y_idx = BcastY ? j : offset;
       x_idx = BcastY ? offset : j;
+      T x_val = (x == nullptr) ? zero : x[x_idx];
+      T y_val = (y == nullptr) ? zero : y[y_idx];
       if (SameShapeOfIntermediateOutAndOut) {
         tmp_out_idx = offset;
@@ -2354,11 +2361,10 @@ static void FusedElemwiseAndActGradBroadcast1CPU(
       if (dx != nullptr) {
         T tmp = UseIntermediateOut
-                    ? dx_op.UseIntermediateOut(x[x_idx], y[y_idx],
+                    ? dx_op.UseIntermediateOut(x_val, y_val,
                                                intermediate_out[tmp_out_idx],
                                                out[offset], dout[offset])
-                    : dx_op.Recompute(x[x_idx], y[y_idx], out[offset],
-                                      dout[offset]);
+                    : dx_op.Recompute(x_val, y_val, out[offset], dout[offset]);
         if (BcastY) {
           dx[x_idx] = tmp;
@@ -2372,11 +2378,10 @@ static void FusedElemwiseAndActGradBroadcast1CPU(
       }
       if (dy != nullptr) {
         T tmp = UseIntermediateOut
-                    ? dy_op.UseIntermediateOut(x[x_idx], y[y_idx],
+                    ? dy_op.UseIntermediateOut(x_val, y_val,
                                                intermediate_out[tmp_out_idx],
                                                out[offset], dout[offset])
-                    : dy_op.Recompute(x[x_idx], y[y_idx], out[offset],
-                                      dout[offset]);
+                    : dy_op.Recompute(x_val, y_val, out[offset], dout[offset]);
         if (BcastY) {
           if (i == 0) {
             dy[y_idx] = tmp;
@@ -2390,10 +2395,10 @@ static void FusedElemwiseAndActGradBroadcast1CPU(
       if (d_intermediate != nullptr) {
         T tmp = UseIntermediateOut
                     ? dintermediate_op.UseIntermediateOut(
-                          x[x_idx], intermediate_out[tmp_out_idx], out[offset],
+                          x_val, intermediate_out[tmp_out_idx], out[offset],
                           dout[offset])
-                    : dintermediate_op.Recompute(x[x_idx], y[y_idx],
-                                                 out[offset], dout[i]);
+                    : dintermediate_op.Recompute(x_val, y_val, out[offset],
+                                                 dout[i]);
         if (SameShapeOfIntermediateOutAndOut) {
           d_intermediate[tmp_out_idx] = tmp;
         } else {
@@ -2416,6 +2421,7 @@ static void FusedElemwiseAndActGradBroadcast2CPU(
     const T *dout, int pre, int n, int post, DX_OP dx_op, DY_OP dy_op,
     DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
   int64_t tmp_out_idx, x_idx, y_idx;
+  T zero = static_cast<T>(0);
   for (int i = 0; i < pre; ++i) {
     for (int j = 0; j < n; ++j) {
       for (int k = 0; k < post; ++k) {
@@ -2425,17 +2431,20 @@ static void FusedElemwiseAndActGradBroadcast2CPU(
         y_idx = BcastY ? j : offset;
         x_idx = BcastY ? offset : j;
+        T x_val = (x == nullptr) ? zero : x[x_idx];
+        T y_val = (y == nullptr) ? zero : y[y_idx];
+
         if (SameShapeOfIntermediateOutAndOut) {
           tmp_out_idx = offset;
         }
         if (dx != nullptr) {
-          T tmp = UseIntermediateOut
-                      ? dx_op.UseIntermediateOut(x[x_idx], y[y_idx],
-                                                 intermediate_out[tmp_out_idx],
-                                                 out[offset], dout[offset])
-                      : dx_op.Recompute(x[x_idx], y[y_idx], out[offset],
-                                        dout[offset]);
+          T tmp =
+              UseIntermediateOut
+                  ? dx_op.UseIntermediateOut(x_val, y_val,
+                                             intermediate_out[tmp_out_idx],
+                                             out[offset], dout[offset])
+                  : dx_op.Recompute(x_val, y_val, out[offset], dout[offset]);
           if (BcastY) {
             dx[x_idx] = tmp;
@@ -2448,12 +2457,12 @@ static void FusedElemwiseAndActGradBroadcast2CPU(
           }
         }
         if (dy != nullptr) {
-          T tmp = UseIntermediateOut
-                      ? dy_op.UseIntermediateOut(x[x_idx], y[y_idx],
-                                                 intermediate_out[tmp_out_idx],
-                                                 out[offset], dout[offset])
-                      : dy_op.Recompute(x[x_idx], y[y_idx], out[offset],
-                                        dout[offset]);
+          T tmp =
+              UseIntermediateOut
+                  ? dy_op.UseIntermediateOut(x_val, y_val,
+                                             intermediate_out[tmp_out_idx],
+                                             out[offset], dout[offset])
+                  : dy_op.Recompute(x_val, y_val, out[offset], dout[offset]);
           if (BcastY) {
             if (i == 0 && k == 0) {
               dy[y_idx] = tmp;
@@ -2467,10 +2476,10 @@ static void FusedElemwiseAndActGradBroadcast2CPU(
         if (d_intermediate != nullptr) {
           T tmp = UseIntermediateOut
                       ? dintermediate_op.UseIntermediateOut(
-                            x[x_idx], intermediate_out[tmp_out_idx],
-                            out[offset], dout[offset])
-                      : dintermediate_op.Recompute(x[x_idx], y[y_idx],
-                                                   out[offset], dout[i]);
+                            x_val, intermediate_out[tmp_out_idx], out[offset],
+                            dout[offset])
+                      : dintermediate_op.Recompute(x_val, y_val, out[offset],
+                                                   dout[i]);
           if (SameShapeOfIntermediateOutAndOut) {
             d_intermediate[tmp_out_idx] = tmp;
           } else {
@@ -2499,6 +2508,7 @@ static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel(
   int tid = threadIdx.x;
   T val(0), inter_val(0);
   int64_t tmp_out_idx, x_idx, y_idx;
+  T zero = static_cast<T>(0);
   do {
     int offset = i * w + j;
@@ -2506,18 +2516,19 @@ static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel(
     tmp_out_idx = BcastY ? j : offset;
     y_idx = BcastY ? j : offset;
     x_idx = BcastY ? offset : j;
+    T x_val = (x == nullptr) ? zero : x[x_idx];
+    T y_val = (y == nullptr) ? zero : y[y_idx];
     if (SameShapeOfIntermediateOutAndOut) {
       tmp_out_idx = offset;
     }
     if (dx != nullptr) {
-      T tmp =
-          UseIntermediateOut
-              ? dx_op.UseIntermediateOut(x[x_idx], y[y_idx],
-                                         intermediate_out[tmp_out_idx],
-                                         out[offset], dout[offset])
-              : dx_op.Recompute(x[x_idx], y[y_idx], out[offset], dout[offset]);
+      T tmp = UseIntermediateOut
+                  ? dx_op.UseIntermediateOut(x_val, y_val,
+                                             intermediate_out[tmp_out_idx],
+                                             out[offset], dout[offset])
+                  : dx_op.Recompute(x_val, y_val, out[offset], dout[offset]);
       if (BcastY) {
         dx[x_idx] = tmp;
@@ -2526,12 +2537,11 @@ static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel(
       }
     }
     if (dy != nullptr) {
-      T tmp =
-          UseIntermediateOut
-              ? dy_op.UseIntermediateOut(x[x_idx], y[y_idx],
-                                         intermediate_out[tmp_out_idx],
-                                         out[offset], dout[offset])
-              : dy_op.Recompute(x[x_idx], y[y_idx], out[offset], dout[offset]);
+      T tmp = UseIntermediateOut
+                  ? dy_op.UseIntermediateOut(x_val, y_val,
+                                             intermediate_out[tmp_out_idx],
+                                             out[offset], dout[offset])
+                  : dy_op.Recompute(x_val, y_val, out[offset], dout[offset]);
       if (BcastY) {
         val += tmp;
       } else {
@@ -2543,7 +2553,7 @@ static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel(
                 ? dintermediate_op.UseIntermediateOut(
                       y[y_idx], intermediate_out[tmp_out_idx], out[offset],
                       dout[offset])
-                : dintermediate_op.Recompute(x[x_idx], y[y_idx], out[offset],
+                : dintermediate_op.Recompute(x_val, y_val, out[offset],
                                              dout[offset]);
     if (SameShapeOfIntermediateOutAndOut) {
       d_intermediate[tmp_out_idx] = tmp;
@@ -2610,6 +2620,7 @@ static __global__ void FusedElemwiseAndActGradBroadcast2CUDAKernel(
   T val(0), inter_val(0);
   int ttid = tid;
   int64_t tmp_out_idx, x_idx, y_idx;
+  T zero = static_cast<T>(0);
   while (true) {
     int i = ttid / post;
     int k = ttid % post;
@@ -2620,18 +2631,19 @@ static __global__ void FusedElemwiseAndActGradBroadcast2CUDAKernel(
     tmp_out_idx = BcastY ? j : offset;
     y_idx = BcastY ? j : offset;
     x_idx = BcastY ? offset : j;
+    T x_val = (x == nullptr) ? zero : x[x_idx];
+    T y_val = (y == nullptr) ? zero : y[y_idx];
     if (SameShapeOfIntermediateOutAndOut) {
       tmp_out_idx = offset;
     }
     if (dx != nullptr) {
-      T tmp =
-          UseIntermediateOut
-              ? dx_op.UseIntermediateOut(x[x_idx], y[y_idx],
-                                         intermediate_out[tmp_out_idx],
-                                         out[offset], dout[offset])
-              : dx_op.Recompute(x[x_idx], y[y_idx], out[offset], dout[offset]);
+      T tmp = UseIntermediateOut
+                  ? dx_op.UseIntermediateOut(x_val, y_val,
+                                             intermediate_out[tmp_out_idx],
+                                             out[offset], dout[offset])
+                  : dx_op.Recompute(x_val, y_val, out[offset], dout[offset]);
       if (BcastY) {
         dx[x_idx] = tmp;
@@ -2640,12 +2652,11 @@ static __global__ void FusedElemwiseAndActGradBroadcast2CUDAKernel(
      }
     }
     if (dy != nullptr) {
-      T tmp =
-          UseIntermediateOut
-              ? dy_op.UseIntermediateOut(x[x_idx], y[y_idx],
-                                         intermediate_out[tmp_out_idx],
-                                         out[offset], dout[offset])
-              : dy_op.Recompute(x[x_idx], y[y_idx], out[offset], dout[offset]);
+      T tmp = UseIntermediateOut
+                  ? dy_op.UseIntermediateOut(x_val, y_val,
+                                             intermediate_out[tmp_out_idx],
+                                             out[offset], dout[offset])
+                  : dy_op.Recompute(x_val, y_val, out[offset], dout[offset]);
       if (BcastY) {
         val += tmp;
       } else {
@@ -2655,9 +2666,9 @@ static __global__ void FusedElemwiseAndActGradBroadcast2CUDAKernel(
     if (d_intermediate != nullptr) {
       T tmp = UseIntermediateOut
                   ? dintermediate_op.UseIntermediateOut(
-                        y[y_idx], intermediate_out[tmp_out_idx], out[offset],
+                        y_val, intermediate_out[tmp_out_idx], out[offset],
                         dout[offset])
-                  : dintermediate_op.Recompute(x[x_idx], y[y_idx], out[offset],
+                  : dintermediate_op.Recompute(x_val, y_val, out[offset],
                                                dout[offset]);
       if (SameShapeOfIntermediateOutAndOut) {
         d_intermediate[tmp_out_idx] = tmp;
@@ -2730,16 +2741,20 @@ void FusedElemwiseAndActGradComputeWithBroadcast(
   int pre, n, post, is_run_common_broadcast;
   get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post, &is_run_common_broadcast);
+  const T *x_data = nullptr;
+  const T *y_data = nullptr;
+  if (x->IsInitialized()) x_data = x->data<T>();
+  if (y->IsInitialized()) y_data = y->data<T>();
+
   if (post == 1) {
     int h = pre;
     int w = n;
     if (platform::is_gpu_place(ctx.GetPlace())) {
 #ifdef __NVCC__
       FusedElemwiseAndActGradBroadcast1CUDA<T, DX_OP, DY_OP, DIntermediate_OP,
                                             UseIntermediateOut, BcastY,
                                             SameShapeOfIntermediateOutAndOut>(
-          ctx.template device_context<DeviceContext>().stream(), x->data<T>(),
-          y->data<T>(),
+          ctx.template device_context<DeviceContext>().stream(), x_data, y_data,
           intermediate_out == nullptr ? nullptr : intermediate_out->data<T>(),
           out->data<T>(), dout->data<T>(), h, w, dx_op, dy_op, dintermediate_op,
           dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
@@ -2751,7 +2766,7 @@ void FusedElemwiseAndActGradComputeWithBroadcast(
       FusedElemwiseAndActGradBroadcast1CPU<T, DX_OP, DY_OP, DIntermediate_OP,
                                            UseIntermediateOut, BcastY,
                                            SameShapeOfIntermediateOutAndOut>(
-          x->data<T>(), y->data<T>(),
+          x_data, y_data,
           intermediate_out == nullptr ? nullptr : intermediate_out->data<T>(),
           out->data<T>(), dout->data<T>(), h, w, dx_op, dy_op, dintermediate_op,
           dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
@@ -2765,8 +2780,7 @@ void FusedElemwiseAndActGradComputeWithBroadcast(
       FusedElemwiseAndActGradBroadcast2CUDA<T, DX_OP, DY_OP, DIntermediate_OP,
                                             UseIntermediateOut, BcastY,
                                             SameShapeOfIntermediateOutAndOut>(
-          ctx.template device_context<DeviceContext>().stream(), x->data<T>(),
-          y->data<T>(),
+          ctx.template device_context<DeviceContext>().stream(), x_data, y_data,
           intermediate_out == nullptr ? nullptr : intermediate_out->data<T>(),
           out->data<T>(), dout->data<T>(), pre, n, post, dx_op, dy_op,
           dintermediate_op,
@@ -2779,7 +2793,7 @@ void FusedElemwiseAndActGradComputeWithBroadcast(
      FusedElemwiseAndActGradBroadcast2CPU<T, DX_OP, DY_OP, DIntermediate_OP,
                                            UseIntermediateOut, BcastY,
                                            SameShapeOfIntermediateOutAndOut>(
-          x->data<T>(), y->data<T>(),
+          x_data, y_data,
           intermediate_out == nullptr ? nullptr : intermediate_out->data<T>(),
           out->data<T>(), dout->data<T>(), pre, n, post, dx_op, dy_op,
           dintermediate_op,
......
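An editor's note on why the forward inputs may legitimately be null in the kernels above: for the add + relu pattern named in the commit message, the backward formulas never read X or Y themselves. A worked derivation (my illustration, not part of the patch): with $z = x + y$ and $o = \mathrm{relu}(z)$,

$$
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial o}\,\mathrm{relu}'(z), \qquad
\frac{\partial L}{\partial y} = \frac{\partial L}{\partial o}\,\mathrm{relu}'(z), \qquad
\mathrm{relu}'(z) = \mathbf{1}[z > 0] = \mathbf{1}[o > 0],
$$

so `dout` plus `out` (or the saved `IntermediateOut`) is enough. Once the grad op declares X and Y as no-need-buffer (see the DECLARE_NO_NEED_BUFFER_VARS_INFERER below), the framework may free those tensors after the forward pass, and the kernels must tolerate `x == nullptr` / `y == nullptr`; substituting zero keeps the arithmetic well defined without affecting the result.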
@@ -287,6 +287,15 @@ class FusedElemwiseActivationGradMaker
   }
 };
 
+class FusedElemwiseAddActivationMaker : public FusedElemwiseActivationMaker {};
+
+template <typename T>
+class FusedElemwiseAddActivationGradMaker
+    : public FusedElemwiseActivationGradMaker<T> {
+ public:
+  using FusedElemwiseActivationGradMaker<T>::FusedElemwiseActivationGradMaker;
+};
+
 class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -361,10 +370,61 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "Y"), ctx.GetPlace());
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.GetPlace());
+  }
+};
+
+class FusedElemwiseAddActivationOp : public FusedElemwiseActivationOp {
+ public:
+  using FusedElemwiseActivationOp::FusedElemwiseActivationOp;
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    FusedElemwiseActivationOp::InferShape(ctx);
+    std::vector<std::string> functor_names =
+        ctx->Attrs().Get<std::vector<std::string>>("functor_list");
+    bool elemntwise_add_detected = false;
+    for (auto names : functor_names) {
+      if (names == "elementwise_add") {
+        elemntwise_add_detected = true;
+        break;
+      }
+    }
+    PADDLE_ENFORCE_EQ(
+        elemntwise_add_detected, true,
+        platform::errors::InvalidArgument(
+            "When the FusedElemwiseAddActivationOp Is used in fused pass, the "
+            "elementwise_add Op must be"
+            "detected and used, Please check the fuse pass pattern"));
+  }
+};
+
+class FusedElemwiseAddActivationOpGrad : public FusedElemwiseActivationOpGrad {
+ public:
+  using FusedElemwiseActivationOpGrad::FusedElemwiseActivationOpGrad;
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    FusedElemwiseActivationOpGrad::InferShape(ctx);
+    std::vector<std::string> functor_names =
+        ctx->Attrs().Get<std::vector<std::string>>("functor_list");
+    bool elemntwise_add_grad_detected = false;
+    for (auto names : functor_names) {
+      if (names == "elementwise_add_grad") {
+        elemntwise_add_grad_detected = true;
+        break;
+      }
+    }
+    PADDLE_ENFORCE_EQ(
+        elemntwise_add_grad_detected, true,
+        platform::errors::InvalidArgument(
+            "When the FusedElemwiseAddActivationOpGrad Is used in fused pass, "
+            "the elementwise_add_grad Op must be"
+            "detected and used, Please check the fuse pass pattern"));
   }
 };
 
+DECLARE_NO_NEED_BUFFER_VARS_INFERER(
+    FusedElemwiseAddActivationNoNeddBufVarInferer, "X", "Y");
+
 }  // namespace operators
 }  // namespace paddle
@@ -390,3 +450,27 @@ REGISTER_OP_CPU_KERNEL(
                                            float>,
     ops::FusedElemwiseActivationGradKernel<paddle::platform::CPUDeviceContext,
                                            double>);
+
+// for memory optimization, we register the fused_elemwise_add_activation OP
+REGISTER_OPERATOR(
+    fused_elemwise_add_activation, ops::FusedElemwiseAddActivationOp,
+    ops::FusedElemwiseAddActivationMaker,
+    ops::FusedElemwiseAddActivationGradMaker<paddle::framework::OpDesc>,
+    ops::FusedElemwiseAddActivationGradMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(fused_elemwise_add_activation_grad,
+                  ops::FusedElemwiseAddActivationNoNeddBufVarInferer,
+                  ops::FusedElemwiseAddActivationOpGrad);
+
+REGISTER_OP_CPU_KERNEL(
+    fused_elemwise_add_activation,
+    ops::FusedElemwiseActivationKernel<paddle::platform::CPUDeviceContext,
+                                       float>,
+    ops::FusedElemwiseActivationKernel<paddle::platform::CPUDeviceContext,
+                                       double>);
+
+REGISTER_OP_CPU_KERNEL(
+    fused_elemwise_add_activation_grad,
+    ops::FusedElemwiseActivationGradKernel<paddle::platform::CPUDeviceContext,
+                                           float>,
+    ops::FusedElemwiseActivationGradKernel<paddle::platform::CPUDeviceContext,
+                                           double>);
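The InferShape overrides above insist that "elementwise_add" (forward) and "elementwise_add_grad" (backward) appear in functor_list. Below is a hedged sketch of constructing the forward op by hand with the same inputs, outputs, and attributes the pass emits in CreateFuseElewiseAddActNode; normally only the pass creates this node, and the variable names here are illustrative assumptions.

```python
# Hedged sketch: hand-building a fused_elemwise_add_activation op with the
# same interface the fuse pass generates. Illustrative only.
import paddle
import paddle.fluid as fluid

paddle.enable_static()

block = fluid.default_main_program().global_block()

def make(name):
    return block.create_var(name=name, shape=[8, 64], dtype="float32")

x, y, out, inter = make("x"), make("y"), make("out"), make("inter")

# functor_list must contain "elementwise_add"; otherwise the
# FusedElemwiseAddActivationOp::InferShape check above raises
# InvalidArgument when the program is compiled or run.
block.append_op(
    type="fused_elemwise_add_activation",
    inputs={"X": x, "Y": y},
    outputs={"Out": out, "IntermediateOut": inter},
    attrs={
        "functor_list": ["elementwise_add", "relu"],
        "save_intermediate_out": True,
    })
```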
@@ -32,3 +32,21 @@ REGISTER_OP_CUDA_KERNEL(
                                            double>,
     ops::FusedElemwiseActivationGradKernel<paddle::platform::CUDADeviceContext,
                                            paddle::platform::float16>);
+
+REGISTER_OP_CUDA_KERNEL(
+    fused_elemwise_add_activation,
+    ops::FusedElemwiseActivationKernel<paddle::platform::CUDADeviceContext,
+                                       float>,
+    ops::FusedElemwiseActivationKernel<paddle::platform::CUDADeviceContext,
+                                       double>,
+    ops::FusedElemwiseActivationKernel<paddle::platform::CUDADeviceContext,
+                                       paddle::platform::float16>);
+
+REGISTER_OP_CUDA_KERNEL(
+    fused_elemwise_add_activation_grad,
+    ops::FusedElemwiseActivationGradKernel<paddle::platform::CUDADeviceContext,
+                                           float>,
+    ops::FusedElemwiseActivationGradKernel<paddle::platform::CUDADeviceContext,
+                                           double>,
+    ops::FusedElemwiseActivationGradKernel<paddle::platform::CUDADeviceContext,
+                                           paddle::platform::float16>);
@@ -77,4 +77,6 @@ class TestMNIST(TestParallelExecutorBase):
 
 if __name__ == '__main__':
+    import paddle
+    paddle.enable_static()
     unittest.main()
@@ -390,4 +390,6 @@ for mode in {0, 1}:
                 grad_chek=False)
 
 if __name__ == '__main__':
+    import paddle
+    paddle.enable_static()
     unittest.main()