未验证 提交 3f653c83 编写于 作者: L Leo Chen 提交者: GitHub

register NoNeedBufferVarsInference for max_pool_grad_op, test=develop (#22055)

* fix test_conv2d_ngraph for grad diff, test=develop

* register NoNeedBufferVarsInference for max_pool_grad_op, test=develop

* refine error message, test=develop

* fix numpy, test=develop

* disable test conv2d_ngraph_op, test=develop
Co-authored-by: NZhang Ting <709968123@qq.com>
上级 5b883789
...@@ -88,19 +88,26 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel { ...@@ -88,19 +88,26 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Mask"), "Input(Mask) must not be null."); PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); ctx->HasInput("Mask"), true,
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), platform::errors::NotFound("Input(Mask) must not be null."));
"Input(X@GRAD) should not be null."); PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) must not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::NotFound("Input(Out@GRAD) should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::NotFound("Output(X@GRAD) should not be null."));
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx, framework::GradVarName("Out")),
ctx.device_context()); ctx.device_context());
} }
}; };
...@@ -302,6 +309,9 @@ class MaxPoolWithIndexGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -302,6 +309,9 @@ class MaxPoolWithIndexGradOpMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
MaxPoolWithIndexOpGradNoNeedBufferVarsInference, "X");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -311,7 +321,8 @@ REGISTER_OPERATOR(max_pool2d_with_index, ops::MaxPoolWithIndexOp, ...@@ -311,7 +321,8 @@ REGISTER_OPERATOR(max_pool2d_with_index, ops::MaxPoolWithIndexOp,
ops::MaxPool2dWithIndexOpMaker, ops::MaxPool2dWithIndexOpMaker,
ops::MaxPoolWithIndexGradOpMaker<paddle::framework::OpDesc>, ops::MaxPoolWithIndexGradOpMaker<paddle::framework::OpDesc>,
ops::MaxPoolWithIndexGradOpMaker<paddle::imperative::OpBase>); ops::MaxPoolWithIndexGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(max_pool2d_with_index_grad, ops::MaxPoolWithIndexOpGrad); REGISTER_OPERATOR(max_pool2d_with_index_grad, ops::MaxPoolWithIndexOpGrad,
ops::MaxPoolWithIndexOpGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
max_pool2d_with_index, max_pool2d_with_index,
...@@ -329,7 +340,8 @@ REGISTER_OPERATOR(max_pool3d_with_index, ops::MaxPoolWithIndexOp, ...@@ -329,7 +340,8 @@ REGISTER_OPERATOR(max_pool3d_with_index, ops::MaxPoolWithIndexOp,
ops::MaxPool3dWithIndexOpMaker, ops::MaxPool3dWithIndexOpMaker,
ops::MaxPoolWithIndexGradOpMaker<paddle::framework::OpDesc>, ops::MaxPoolWithIndexGradOpMaker<paddle::framework::OpDesc>,
ops::MaxPoolWithIndexGradOpMaker<paddle::imperative::OpBase>); ops::MaxPoolWithIndexGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(max_pool3d_with_index_grad, ops::MaxPoolWithIndexOpGrad); REGISTER_OPERATOR(max_pool3d_with_index_grad, ops::MaxPoolWithIndexOpGrad,
ops::MaxPoolWithIndexOpGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
max_pool3d_with_index, max_pool3d_with_index,
......
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
list(REMOVE_ITEM TEST_OPS test_conv2d_ngraph_op)
foreach(TEST_OP ${TEST_OPS}) foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS FLAGS_use_ngraph=true) py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS FLAGS_use_ngraph=true)
endforeach(TEST_OP) endforeach(TEST_OP)
...@@ -17,36 +17,42 @@ from __future__ import print_function ...@@ -17,36 +17,42 @@ from __future__ import print_function
import unittest, sys import unittest, sys
sys.path.append("../") sys.path.append("../")
from test_conv2d_op import TestConv2dOp, TestWithPad, TestWithStride, TestWithGroup, TestWith1x1, TestWithInput1x1Filter1x1, TestDepthwiseConv, TestDepthwiseConv2, TestDepthwiseConv3, TestDepthwiseConvWithDilation, TestDepthwiseConvWithDilation2 from test_conv2d_op import TestConv2dOp, TestWithPad, TestWithStride, TestWithGroup, TestWith1x1, TestWithInput1x1Filter1x1, TestDepthwiseConv, TestDepthwiseConv2, TestDepthwiseConv3, TestDepthwiseConvWithDilation, TestDepthwiseConvWithDilation2
import numpy as np
class TestNGRAPHDepthwiseConv(TestDepthwiseConv): class TestNGRAPHDepthwiseConv(TestDepthwiseConv):
def init_test_case(self): def init_test_case(self):
super(TestNGRAPHDepthwiseConv, self).init_test_case() super(TestNGRAPHDepthwiseConv, self).init_test_case()
self.use_cuda = False self.use_cuda = False
self.dtype = np.float32
class TestNGRAPHDepthwiseConv2(TestDepthwiseConv2): class TestNGRAPHDepthwiseConv2(TestDepthwiseConv2):
def init_test_case(self): def init_test_case(self):
super(TestNGRAPHDepthwiseConv2, self).init_test_case() super(TestNGRAPHDepthwiseConv2, self).init_test_case()
self.use_cuda = False self.use_cuda = False
self.dtype = np.float32
class TestNGRAPHDepthwiseConv3(TestDepthwiseConv3): class TestNGRAPHDepthwiseConv3(TestDepthwiseConv3):
def init_test_case(self): def init_test_case(self):
super(TestNGRAPHDepthwiseConv3, self).init_test_case() super(TestNGRAPHDepthwiseConv3, self).init_test_case()
self.use_cuda = False self.use_cuda = False
self.dtype = np.float32
class TestNGRAPHDepthwiseConvWithDilation(TestDepthwiseConvWithDilation): class TestNGRAPHDepthwiseConvWithDilation(TestDepthwiseConvWithDilation):
def init_test_case(self): def init_test_case(self):
super(TestNGRAPHDepthwiseConvWithDilation, self).init_test_case() super(TestNGRAPHDepthwiseConvWithDilation, self).init_test_case()
self.use_cuda = False self.use_cuda = False
self.dtype = np.float32
class TestNGRAPHDepthwiseConvWithDilation2(TestDepthwiseConvWithDilation2): class TestNGRAPHDepthwiseConvWithDilation2(TestDepthwiseConvWithDilation2):
def init_test_case(self): def init_test_case(self):
super(TestNGRAPHDepthwiseConvWithDilation2, self).init_test_case() super(TestNGRAPHDepthwiseConvWithDilation2, self).init_test_case()
self.use_cuda = False self.use_cuda = False
self.dtype = np.float32
del TestDepthwiseConv, TestDepthwiseConv2, TestDepthwiseConv3, TestDepthwiseConvWithDilation, TestDepthwiseConvWithDilation2 del TestDepthwiseConv, TestDepthwiseConv2, TestDepthwiseConv3, TestDepthwiseConvWithDilation, TestDepthwiseConvWithDilation2
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册