未验证 提交 7ac4818a 编写于 作者: Q qingqing01 提交者: GitHub

Refine Infershape in activation_op for double_grad. (#18485)

* Refine Infershape in activation_op for double_grad.
上级 602cb6a5
无相关合并请求
......@@ -604,21 +604,48 @@ class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
if (HasOutputs("DX") && ctx->HasOutput("DX")) {
if (ctx->HasOutput("DX")) {
ctx->ShareDim("X", "DX");
ctx->ShareLoD("X", "DX");
}
if (HasOutputs("DDOut") && ctx->HasOutput("DDOut")) {
if (ctx->HasOutput("DDOut")) {
ctx->ShareDim("X", "DDOut");
ctx->ShareLoD("X", "DDOut");
}
}
if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
if (HasOutputs("DOut") && ctx->HasOutput("DOut")) {
if (ctx->HasOutput("DOut")) {
ctx->ShareDim("Out", "DOut");
ctx->ShareLoD("Out", "DOut");
}
if (HasOutputs("DDOut") && ctx->HasOutput("DDOut")) {
if (ctx->HasOutput("DDOut")) {
ctx->ShareDim("Out", "DDOut");
ctx->ShareLoD("Out", "DDOut");
}
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "DDX");
}
};
template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
if (ctx->HasOutput("DDOut")) {
ctx->ShareDim("X", "DDOut");
ctx->ShareLoD("X", "DDOut");
}
}
if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
if (ctx->HasOutput("DDOut")) {
ctx->ShareDim("Out", "DDOut");
ctx->ShareLoD("Out", "DDOut");
}
......@@ -771,7 +798,7 @@ REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
ops::ReluDoubleGradMaker);
REGISTER_OPERATOR(
relu_grad_grad,
ops::ActivationOpDoubleGrad<ops::ReluGradFunctor<float>::FwdDeps()>);
ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>);
REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);
......@@ -796,7 +823,7 @@ REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad,
ops::LeakyReluDoubleGradMaker);
REGISTER_OPERATOR(
leaky_relu_grad_grad,
ops::ActivationOpDoubleGrad<ops::LeakyReluGradFunctor<float>::FwdDeps()>);
ops::ActivationOpDoubleGrad2<ops::LeakyReluGradFunctor<float>::FwdDeps()>);
REGISTER_ACTIVATION_CPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor,
LeakyReluGradFunctor);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部