Unverified commit 710767d8 authored by Zeng Jinle, committed by GitHub

Enable inplace support for some ops (#19612)

* enable inplace for affine_channel op, dropout op, test=develop

* remove dropout inplace because it fails under ngraph, test=develop
Parent a18cf5e1
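A hedged sketch, not part of the commit, of the hazard that motivates the gradient-kernel reordering in the diff below: once the grad op may run in place, Grad(X) reuses Grad(Out)'s buffer, so any reduction that reads Grad(Out) must finish before the elementwise write of Grad(X) clobbers it. The sketch uses a scalar scale for brevity, whereas the real kernel applies it per channel.

```cpp
// Hedged sketch, not Paddle code: why reductions over dy must precede the
// in-place write of dx when dx and dy share one buffer.
#include <cstddef>

void AffineChannelGradSketch(const float* x, float scale, float* dy,
                             float* dscale, float* dbias, float* dx,
                             std::size_t n) {
  // 1) Reductions read every element of the original dy, so they run first.
  float ds = 0.f, db = 0.f;
  for (std::size_t i = 0; i < n; ++i) {
    ds += x[i] * dy[i];
    db += dy[i];
  }
  *dscale = ds;
  *dbias = db;
  // 2) The elementwise pass is safe even if dx == dy: element i reads only
  //    dy[i] before writing dx[i].
  for (std::size_t i = 0; i < n; ++i) {
    dx[i] = dy[i] * scale;
  }
}
```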
@@ -238,21 +238,11 @@ class AffineChannelGradKernel : public framework::OpKernel<T> {
EigenVectorArrayMap<T> dbias_e(dbias_d, C);
if (layout == framework::DataLayout::kNCHW) {
// compute dx
int stride = C * HxW;
if (dx) {
for (int i = 0; i < N; i++) {
ConstEigenArrayMap<T> dy_e(dy_d, HxW, C);
EigenArrayMap<T> dx_e(dx_d, HxW, C);
dx_e = dy_e.rowwise() * scale_e.transpose();
dy_d += stride;
dx_d += stride;
}
}
// compute dscale and dbias
int stride = C * HxW;
auto* original_dy_d = dy_d;
if (dscale && dbias) {
auto* x_d = x->data<T>();
dy_d = dy->data<T>();
for (int i = 0; i < N; i++) {
ConstEigenArrayMap<T> x_e(x_d, HxW, C);
ConstEigenArrayMap<T> dy_e(dy_d, HxW, C);
@@ -270,14 +260,21 @@ class AffineChannelGradKernel : public framework::OpKernel<T> {
dy_d += stride;
}
}
} else {
int num = N * HxW;
ConstEigenArrayMap<T> dy_e(dy_d, C, num);
// compute dx
if (dx) {
EigenArrayMap<T> dx_e(dx_d, C, num);
dx_e = dy_e.colwise() * scale_e;
dy_d = original_dy_d;
for (int i = 0; i < N; i++) {
ConstEigenArrayMap<T> dy_e(dy_d, HxW, C);
EigenArrayMap<T> dx_e(dx_d, HxW, C);
dx_e = dy_e.rowwise() * scale_e.transpose();
dy_d += stride;
dx_d += stride;
}
}
} else {
int num = N * HxW;
ConstEigenArrayMap<T> dy_e(dy_d, C, num);
// compute dscale and dbias
if (dscale && dbias) {
auto* x_d = x->data<T>();
@@ -285,6 +282,12 @@ class AffineChannelGradKernel : public framework::OpKernel<T> {
dscale_e = (x_e * dy_e).rowwise().sum();
dbias_e = dy_e.rowwise().sum();
}
// compute dx
if (dx) {
EigenArrayMap<T> dx_e(dx_d, C, num);
dx_e = dy_e.colwise() * scale_e;
}
}
}
};
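A note on the hunk above: `original_dy_d` is saved before the dscale/dbias loop because that loop advances `dy_d` across all `N` batch slices; the dx loop then rewinds `dy_d` to the start of the (possibly shared) buffer before its own pass.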
@@ -316,6 +319,11 @@ class AffineChannelNoNeedBufferVarsInference
}
};
DECLARE_INPLACE_OP_INFERER(AffineChannelInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(AffineChannelGradInplaceInferer,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
} // namespace operators
} // namespace paddle
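A hedged sketch of what `DECLARE_INPLACE_OP_INFERER` presumably expands to, modeled on the hand-written `SoftmaxGradInplaceInferer` further down in this commit; the actual macro body may differ.

```cpp
// Roughly equivalent hand-written form of
// DECLARE_INPLACE_OP_INFERER(AffineChannelGradInplaceInferer,
//                            {framework::GradVarName("Out"),
//                             framework::GradVarName("X")});
class AffineChannelGradInplaceInferer : public framework::InplaceOpInference {
 public:
  using framework::InplaceOpInference::InplaceOpInference;

  std::unordered_map<std::string, std::string> operator()(
      const framework::OpDesc& op_desc, bool use_cuda) const final {
    // Unconditionally propose reusing Grad(Out)'s buffer for Grad(X).
    return {{framework::GradVarName("Out"), framework::GradVarName("X")}};
  }
};
```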
@@ -323,9 +331,11 @@ namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(affine_channel, ops::AffineChannelOp,
ops::AffineChannelOpMaker, ops::AffineChannelGradMaker);
ops::AffineChannelOpMaker, ops::AffineChannelGradMaker,
ops::AffineChannelInplaceInferer);
REGISTER_OPERATOR(affine_channel_grad, ops::AffineChannelOpGrad,
ops::AffineChannelNoNeedBufferVarsInference);
ops::AffineChannelNoNeedBufferVarsInference,
ops::AffineChannelGradInplaceInferer);
REGISTER_OP_CPU_KERNEL(affine_channel, ops::AffineChannelKernel<CPU, float>,
ops::AffineChannelKernel<CPU, double>);
@@ -151,11 +151,6 @@ class AffineChannelGradCUDAKernel : public framework::OpKernel<T> {
int grid1 = (num + block - 1) / block;
int grid2 = std::min(C, max_blocks);
if (layout == framework::DataLayout::kNCHW) {
if (dx) {
KeAffineChannelCUDA<T, framework::DataLayout::kNCHW,
false><<<grid1, block, 0, dev_ctx.stream()>>>(
dy_d, s_d, nullptr, C, HxW, num, dx_d);
}
if (dscale && dbias) {
const T* x_d = x->data<T>();
AffineChannelScaleBiasGradientCUDAKernel<
@@ -163,12 +158,12 @@ class AffineChannelGradCUDAKernel : public framework::OpKernel<T> {
dev_ctx.stream()>>>(
dy_d, x_d, N, C, HxW, ds_d, db_d);
}
} else {
if (dx) {
KeAffineChannelCUDA<T, framework::DataLayout::kNHWC,
KeAffineChannelCUDA<T, framework::DataLayout::kNCHW,
false><<<grid1, block, 0, dev_ctx.stream()>>>(
dy_d, s_d, nullptr, C, HxW, num, dx_d);
}
} else {
if (dscale && dbias) {
const T* x_d = x->data<T>();
AffineChannelScaleBiasGradientCUDAKernel<
@@ -176,6 +171,12 @@ class AffineChannelGradCUDAKernel : public framework::OpKernel<T> {
dev_ctx.stream()>>>(
dy_d, x_d, N, C, HxW, ds_d, db_d);
}
if (dx) {
KeAffineChannelCUDA<T, framework::DataLayout::kNHWC,
false><<<grid1, block, 0, dev_ctx.stream()>>>(
dy_d, s_d, nullptr, C, HxW, num, dx_d);
}
}
}
};
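The CUDA kernel needs only the same source-order swap, with no extra synchronization, because both launches target `dev_ctx.stream()` and kernels on one CUDA stream execute in launch order. A minimal self-contained sketch with hypothetical kernels, not Paddle's, illustrating that guarantee:

```cpp
// Hedged CUDA sketch: two launches on one stream are serialized, so the
// reduction reads dy before the elementwise kernel overwrites the shared
// dy/dx buffer. Kernel names and launch shapes are illustrative.
#include <cuda_runtime.h>

__global__ void SumDy(const float* dy, float* sum, int n) {
  // Deliberately naive single-thread reduction, for illustration only.
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    float s = 0.f;
    for (int i = 0; i < n; ++i) s += dy[i];
    *sum = s;
  }
}

__global__ void ScaleInplace(float* buf, float scale, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) buf[i] *= scale;  // writes dx into the buffer that held dy
}

void RunGradInplace(float* dy_dx, float* sum, float scale, int n,
                    cudaStream_t stream) {
  SumDy<<<1, 1, 0, stream>>>(dy_dx, sum, n);  // reads the original dy
  const int block = 256;
  const int grid = (n + block - 1) / block;
  // Stream order alone guarantees this runs after the reduction completes.
  ScaleInplace<<<grid, block, 0, stream>>>(dy_dx, scale, n);
}
```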
@@ -93,13 +93,18 @@ class ClipGradOpDescMaker : public framework::SingleGradOpDescMaker {
}
};
DECLARE_INPLACE_OP_INFERER(ClipInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(ClipGradInplaceInferer,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(clip, ops::ClipOp, ops::ClipOpMaker<float>,
ops::ClipGradOpDescMaker);
REGISTER_OPERATOR(clip_grad, ops::ClipOpGrad);
ops::ClipGradOpDescMaker, ops::ClipInplaceInferer);
REGISTER_OPERATOR(clip_grad, ops::ClipOpGrad, ops::ClipGradInplaceInferer);
REGISTER_OP_CPU_KERNEL(
clip, ops::ClipKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
@@ -220,14 +220,33 @@ class SoftmaxOpGradMaker : public framework::SingleGradOpDescMaker {
}
};
DECLARE_INPLACE_OP_INFERER(SoftmaxInplaceInferer, {"X", "Out"});
class SoftmaxGradInplaceInferer final : public framework::InplaceOpInference {
public:
using framework::InplaceOpInference::InplaceOpInference;
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc& op_desc, bool use_cuda) const final {
if (use_cuda) {
return {{"Out", framework::GradVarName("X")}};
} else {
// NOTE(zjl): AVX implementation of SoftmaxGrad does not support in-place
return {};
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker,
ops::SoftmaxOpInferVarType, ops::SoftmaxOpGradMaker);
REGISTER_OPERATOR(softmax_grad, ops::SoftmaxOpGrad);
ops::SoftmaxOpInferVarType, ops::SoftmaxOpGradMaker,
ops::SoftmaxInplaceInferer);
REGISTER_OPERATOR(softmax_grad, ops::SoftmaxOpGrad,
ops::SoftmaxGradInplaceInferer);
REGISTER_OP_CPU_KERNEL(
softmax, ops::SoftmaxKernel<paddle::platform::CPUDeviceContext, float>,
ops::SoftmaxKernel<paddle::platform::CPUDeviceContext, double>);
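Background, as a hedged note rather than part of the commit, on why `SoftmaxGradInplaceInferer` maps `Out` to `Grad(X)` instead of the usual `Grad(Out)`/`Grad(X)` pair: with $y = \operatorname{softmax}(x)$, the backward pass is

$$dx_i = y_i \Bigl(dy_i - \sum_j dy_j\, y_j\Bigr),$$

so computing any $dx_i$ depends on all of $y$ through the inner sum. Reusing $y$'s buffer for $dx$ is safe only if the implementation finishes reading $y$ (for example by materializing the sum first) before it starts writing, which evidently holds for the CUDA path but, per the NOTE above, not for the AVX SoftmaxGrad on CPU; hence the device-conditional map.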
@@ -287,12 +287,19 @@ class Squeeze2GradOp : public framework::OperatorBase {
attrs["shape"] = framework::vectorize<int>(x_dims);
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape2", {{"X", {dout_name}}, {"Shape", {}}},
{{"Out", {dx_name}}, {"XShape", {xshape_name}}}, attrs);
"reshape2_grad", {{framework::GradVarName("Out"), {dout_name}},
{"Shape", {}},
{"XShape", {xshape_name}}},
{{framework::GradVarName("X"), {dx_name}}}, attrs);
reshape_op->Run(scope, place);
}
};
DECLARE_INPLACE_OP_INFERER(SequeezeInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(SequeezeGradInplaceInferer,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
} // namespace operators
} // namespace paddle
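The grad op now delegates to `reshape2_grad`, feeding it the saved `XShape`, instead of running the forward `reshape2` and forcing it to produce a fresh `XShape` output; presumably this also lets `SequeezeGradInplaceInferer` line up with `reshape2_grad`'s own in-place rules. A hedged helper sketch of how the original dims are conventionally recovered from `XShape` (assuming Paddle's convention of storing them behind a leading 0 sentinel; the name below is illustrative, not Paddle's API):

```cpp
// Hedged sketch: strip the leading 0 sentinel from an XShape dim vector to
// recover the dims of the original input x.
#include <cstdint>
#include <vector>

std::vector<int64_t> DimsFromXShape(const std::vector<int64_t>& xshape_dims) {
  // XShape is conventionally stored as {0, d0, d1, ...}; drop the sentinel.
  return std::vector<int64_t>(xshape_dims.begin() + 1, xshape_dims.end());
}
```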
@@ -306,6 +313,7 @@ REGISTER_OPERATOR(squeeze, ops::SqueezeOp, ops::SqueezeOpMaker,
REGISTER_OPERATOR(squeeze_grad, ops::SqueezeGradOp, ops::SqueezeGradInferShape);
REGISTER_OPERATOR(squeeze2, ops::Squeeze2Op, ops::Squeeze2OpMaker,
ops::Squeeze2OpInferShape, ops::Squeeze2GradOpMaker);
ops::Squeeze2OpInferShape, ops::Squeeze2GradOpMaker,
ops::SequeezeInplaceInferer);
REGISTER_OPERATOR(squeeze2_grad, ops::Squeeze2GradOp,
ops::Squeeze2GradInferShape);
ops::Squeeze2GradInferShape, ops::SequeezeGradInplaceInferer);
@@ -269,11 +269,19 @@ class Unsqueeze2GradOp : public framework::OperatorBase {
attrs["shape"] = framework::vectorize2int(x_dims);
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape2", {{"X", {dout_name}}, {"Shape", {}}},
{{"Out", {dx_name}}, {"XShape", {xshape_name}}}, attrs);
"reshape2_grad", {{framework::GradVarName("Out"), {dout_name}},
{"Shape", {}},
{"XShape", {xshape_name}}},
{{framework::GradVarName("X"), {dx_name}}}, attrs);
reshape_op->Run(scope, place);
}
};
DECLARE_INPLACE_OP_INFERER(UnsqueezeInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(UnsqueezeGradInplaceInferer,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
} // namespace operators
} // namespace paddle
@@ -288,6 +296,8 @@ REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp,
ops::UnsqueezeGradInferShape);
REGISTER_OPERATOR(unsqueeze2, ops::Unsqueeze2Op, ops::Unsqueeze2OpMaker,
ops::Unsqueeze2OpInferShape, ops::Unsqueeze2GradOpMaker);
ops::Unsqueeze2OpInferShape, ops::Unsqueeze2GradOpMaker,
ops::UnsqueezeInplaceInferer);
REGISTER_OPERATOR(unsqueeze2_grad, ops::Unsqueeze2GradOp,
ops::Unsqueeze2GradInferShape);
ops::Unsqueeze2GradInferShape,
ops::UnsqueezeGradInplaceInferer);
@@ -109,8 +109,8 @@ class TestAffineChannelNCHWLargeShape(TestAffineChannelOp):
class TestAffineChannelNHWCLargeShape(TestAffineChannelNCHWLargeShape):
def init_test_case(self):
self.shape = [64, 32, 32, 512]
self.C = 512
self.shape = [64, 32, 32, 128]
self.C = 128
self.layout = 'NHWC'