Unverified commit 11a1284c, authored by qingqing01, committed by GitHub

Refine Infershape in activation_op for double_grad (#18731)

Parent 45cb4795
@@ -604,25 +604,48 @@ class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
-      // some op has no output DX, check HasOutputs("DX") here
-      if (HasOutputs("DX") && ctx->HasOutput("DX")) {
+      if (ctx->HasOutput("DX")) {
         ctx->ShareDim("X", "DX");
         ctx->ShareLoD("X", "DX");
       }
-      // some op has no output DDout, check HasOutputs("DDout") here
-      if (HasOutputs("DDOut") && ctx->HasOutput("DDOut")) {
+      if (ctx->HasOutput("DDOut")) {
         ctx->ShareDim("X", "DDOut");
         ctx->ShareLoD("X", "DDOut");
       }
     }
     if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
-      // some op has no output DOut, check HasOutputs("DOut") here
-      if (HasOutputs("DOut") && ctx->HasOutput("DOut")) {
+      if (ctx->HasOutput("DOut")) {
         ctx->ShareDim("Out", "DOut");
         ctx->ShareLoD("Out", "DOut");
       }
-      // some op has no output DDOut, check HasOutputs("DDOut") here
-      if (HasOutputs("DDOut") && ctx->HasOutput("DDOut")) {
+      if (ctx->HasOutput("DDOut")) {
         ctx->ShareDim("Out", "DDOut");
         ctx->ShareLoD("Out", "DDOut");
       }
     }
   }
 
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     return GetKernelType(ctx, *this, "DDX");
   }
 };
 
+template <ActBwdOpFwdDeps kDepValue>
+class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
+      if (ctx->HasOutput("DDOut")) {
+        ctx->ShareDim("X", "DDOut");
+        ctx->ShareLoD("X", "DDOut");
+      }
+    }
+    if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
+      if (ctx->HasOutput("DDOut")) {
+        ctx->ShareDim("Out", "DDOut");
+        ctx->ShareLoD("Out", "DDOut");
+      }
@@ -775,7 +798,7 @@ REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
                   ops::ReluDoubleGradMaker);
 REGISTER_OPERATOR(
     relu_grad_grad,
-    ops::ActivationOpDoubleGrad<ops::ReluGradFunctor<float>::FwdDeps()>);
+    ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>);
 
 REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);
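The template argument ops::ReluGradFunctor<float>::FwdDeps() in the registration above is a constexpr call, so the dependency mask is fixed at compile time. A hedged sketch of that hook, reusing the enum from the sketch above (ReluGradFunctorSketch is an illustrative stand-in, assuming relu's backward reads only Out, i.e. dx = dout * (out > 0)):

// Illustrative stand-in for ops::ReluGradFunctor; the only part relevant
// to operator registration is the constexpr FwdDeps() hook.
struct ReluGradFunctorSketch {
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

// ActivationOpDoubleGrad2<ReluGradFunctorSketch::FwdDeps()> would thus
// instantiate the kDepOut branch of InferShape.
static_assert(ReluGradFunctorSketch::FwdDeps() == kDepOut,
              "dependency mask is a compile-time constant");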
@@ -800,7 +823,7 @@ REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad,
                   ops::LeakyReluDoubleGradMaker);
 REGISTER_OPERATOR(
     leaky_relu_grad_grad,
-    ops::ActivationOpDoubleGrad<ops::LeakyReluGradFunctor<float>::FwdDeps()>);
+    ops::ActivationOpDoubleGrad2<ops::LeakyReluGradFunctor<float>::FwdDeps()>);
 
 REGISTER_ACTIVATION_CPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor,
                                LeakyReluGradFunctor);
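Why relu and leaky_relu can both move to the slimmer ActivationOpDoubleGrad2: their forward maps are piecewise linear, so the second derivative vanishes almost everywhere and the double-grad kernel only needs to emit DDOut. Sketched in LaTeX for relu, with Out = max(X, 0):

% The relu Jacobian is the indicator 1[Out > 0]; it is piecewise constant,
% so differentiating it again contributes no DX term.
\mathrm{DDOut} = \mathrm{DDX} \cdot \mathbf{1}[\mathrm{Out} > 0], \qquad \mathrm{DX} = 0
% leaky_relu is analogous, with slope \alpha on the negative side:
% \mathrm{DDOut} = \mathrm{DDX} \cdot (\mathbf{1}[X > 0] + \alpha\,\mathbf{1}[X \le 0]), \quad \mathrm{DX} = 0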