Unverified commit 7ac4818a, authored by qingqing01, committed by GitHub

Refine Infershape in activation_op for double_grad. (#18485)

* Refine Infershape in activation_op for double_grad.
Parent 602cb6a5
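In brief, per the diff below: ActivationOpDoubleGrad::InferShape previously guarded each output with both HasOutputs(...) and ctx->HasOutput(...), and the redundant HasOutputs call is dropped. The commit also introduces a second operator class, ActivationOpDoubleGrad2, whose InferShape only infers the DDOut output, and re-registers relu_grad_grad and leaky_relu_grad_grad with it, since those double-grad ops produce only DDOut. A standalone sketch of the kDepValue dependency-mask mechanism follows the diff.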
@@ -604,21 +604,48 @@ class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
-      if (HasOutputs("DX") && ctx->HasOutput("DX")) {
+      if (ctx->HasOutput("DX")) {
         ctx->ShareDim("X", "DX");
         ctx->ShareLoD("X", "DX");
       }
-      if (HasOutputs("DDOut") && ctx->HasOutput("DDOut")) {
+      if (ctx->HasOutput("DDOut")) {
         ctx->ShareDim("X", "DDOut");
         ctx->ShareLoD("X", "DDOut");
       }
     }
     if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
-      if (HasOutputs("DOut") && ctx->HasOutput("DOut")) {
+      if (ctx->HasOutput("DOut")) {
         ctx->ShareDim("Out", "DOut");
         ctx->ShareLoD("Out", "DOut");
       }
-      if (HasOutputs("DDOut") && ctx->HasOutput("DDOut")) {
+      if (ctx->HasOutput("DDOut")) {
+        ctx->ShareDim("Out", "DDOut");
+        ctx->ShareLoD("Out", "DDOut");
+      }
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return GetKernelType(ctx, *this, "DDX");
+  }
+};
+
+template <ActBwdOpFwdDeps kDepValue>
+class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
+      if (ctx->HasOutput("DDOut")) {
+        ctx->ShareDim("X", "DDOut");
+        ctx->ShareLoD("X", "DDOut");
+      }
+    }
+    if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
+      if (ctx->HasOutput("DDOut")) {
         ctx->ShareDim("Out", "DDOut");
         ctx->ShareLoD("Out", "DDOut");
       }
     }
@@ -771,7 +798,7 @@ REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
                   ops::ReluDoubleGradMaker);
 REGISTER_OPERATOR(
     relu_grad_grad,
-    ops::ActivationOpDoubleGrad<ops::ReluGradFunctor<float>::FwdDeps()>);
+    ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>);
 
 REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);
@@ -796,7 +823,7 @@ REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad,
                   ops::LeakyReluDoubleGradMaker);
 REGISTER_OPERATOR(
     leaky_relu_grad_grad,
-    ops::ActivationOpDoubleGrad<ops::LeakyReluGradFunctor<float>::FwdDeps()>);
+    ops::ActivationOpDoubleGrad2<ops::LeakyReluGradFunctor<float>::FwdDeps()>);
 
 REGISTER_ACTIVATION_CPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor,
                                LeakyReluGradFunctor);
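For readers unfamiliar with the kDepValue template parameter: it is a bitmask declaring which forward tensors (the input X, the output Out, or neither) an activation's backward and double-backward kernels need, and InferShape tests it to decide whose shape and LoD to propagate into the double-grad outputs. Below is a minimal standalone sketch of that mechanism, not PaddlePaddle code. The ActBwdOpFwdDeps values and the per-functor FwdDeps() results are assumptions based on activation_op.h of this era, not part of this diff.

// Minimal standalone sketch of the dependency-mask pattern used above.
// Assumption: ActBwdOpFwdDeps is defined in activation_op.h roughly as below.
#include <iostream>

enum ActBwdOpFwdDeps {
  kNoDeps = 0x00,  // backward needs no forward tensor
  kDepX = 0x01,    // backward needs the forward input X
  kDepOut = 0x02,  // backward needs the forward output Out
};

// Mirrors ActivationOpDoubleGrad2::InferShape: share the dims/LoD of X or
// Out into DDOut, depending on which forward tensor the functor declares.
template <ActBwdOpFwdDeps kDepValue>
void InferDoubleGradShape(bool has_ddout) {
  if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
    if (has_ddout) std::cout << "DDOut takes dim/LoD of X\n";
  }
  if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
    if (has_ddout) std::cout << "DDOut takes dim/LoD of Out\n";
  }
}

int main() {
  // relu's gradient depends on Out (out > 0) and leaky_relu's on X, so
  // their FwdDeps() are assumed to be kDepOut and kDepX respectively.
  InferDoubleGradShape<kDepOut>(true);  // relu_grad_grad
  InferDoubleGradShape<kDepX>(true);    // leaky_relu_grad_grad
}

The design point the sketch illustrates: relu_grad_grad and leaky_relu_grad_grad declare only a DDOut output, so the generic ActivationOpDoubleGrad, which also tries to infer DX and DOut, is replaced for them by the narrower ActivationOpDoubleGrad2.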