diff --git a/paddle/fluid/operators/pool_with_index_op.cc b/paddle/fluid/operators/pool_with_index_op.cc index 7d8e1c498151cb502905fe8ba9e7cd9388c277cc..113fcd30e656e08d90000a863ccf5213a797b42c 100644 --- a/paddle/fluid/operators/pool_with_index_op.cc +++ b/paddle/fluid/operators/pool_with_index_op.cc @@ -88,19 +88,26 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Mask"), "Input(Mask) must not be null."); - PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); - PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), - "Input(X@GRAD) should not be null."); + PADDLE_ENFORCE_EQ( + ctx->HasInput("Mask"), true, + platform::errors::NotFound("Input(Mask) must not be null.")); + PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, + platform::errors::NotFound("Input(X) must not be null.")); + PADDLE_ENFORCE_EQ( + ctx->HasInput(framework::GradVarName("Out")), true, + platform::errors::NotFound("Input(Out@GRAD) should not be null.")); + PADDLE_ENFORCE_EQ( + ctx->HasOutput(framework::GradVarName("X")), true, + platform::errors::NotFound("Output(X@GRAD) should not be null.")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); } protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; @@ -302,6 +309,9 @@ class MaxPoolWithIndexGradOpMaker : public framework::SingleGradOpMaker { } }; +DECLARE_NO_NEED_BUFFER_VARS_INFERENCE( + MaxPoolWithIndexOpGradNoNeedBufferVarsInference, "X"); + } // namespace operators } // namespace paddle @@ -311,7 +321,8 @@ REGISTER_OPERATOR(max_pool2d_with_index, ops::MaxPoolWithIndexOp, ops::MaxPool2dWithIndexOpMaker, ops::MaxPoolWithIndexGradOpMaker, ops::MaxPoolWithIndexGradOpMaker); -REGISTER_OPERATOR(max_pool2d_with_index_grad, ops::MaxPoolWithIndexOpGrad); +REGISTER_OPERATOR(max_pool2d_with_index_grad, ops::MaxPoolWithIndexOpGrad, + ops::MaxPoolWithIndexOpGradNoNeedBufferVarsInference); REGISTER_OP_CPU_KERNEL( max_pool2d_with_index, @@ -329,7 +340,8 @@ REGISTER_OPERATOR(max_pool3d_with_index, ops::MaxPoolWithIndexOp, ops::MaxPool3dWithIndexOpMaker, ops::MaxPoolWithIndexGradOpMaker, ops::MaxPoolWithIndexGradOpMaker); -REGISTER_OPERATOR(max_pool3d_with_index_grad, ops::MaxPoolWithIndexOpGrad); +REGISTER_OPERATOR(max_pool3d_with_index_grad, ops::MaxPoolWithIndexOpGrad, + ops::MaxPoolWithIndexOpGradNoNeedBufferVarsInference); REGISTER_OP_CPU_KERNEL( max_pool3d_with_index, diff --git a/python/paddle/fluid/tests/unittests/ngraph/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ngraph/CMakeLists.txt index 5ed2d0aa80cd0462d3ac1902a2ec13fc2c1bd844..e9866699a604bc119b45e1b21155694e1e16d396 100644 --- a/python/paddle/fluid/tests/unittests/ngraph/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/ngraph/CMakeLists.txt @@ -1,6 +1,8 @@ file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") +list(REMOVE_ITEM TEST_OPS test_conv2d_ngraph_op) + foreach(TEST_OP ${TEST_OPS}) py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS FLAGS_use_ngraph=true) endforeach(TEST_OP) diff --git a/python/paddle/fluid/tests/unittests/ngraph/test_conv2d_ngraph_op.py b/python/paddle/fluid/tests/unittests/ngraph/test_conv2d_ngraph_op.py index fc2031c4cfc21ff2d895d304b451215dc34b4462..4894af949a20bad59836989060842c82b78d5cfe 100644 --- a/python/paddle/fluid/tests/unittests/ngraph/test_conv2d_ngraph_op.py +++ b/python/paddle/fluid/tests/unittests/ngraph/test_conv2d_ngraph_op.py @@ -17,36 +17,42 @@ from __future__ import print_function import unittest, sys sys.path.append("../") from test_conv2d_op import TestConv2dOp, TestWithPad, TestWithStride, TestWithGroup, TestWith1x1, TestWithInput1x1Filter1x1, TestDepthwiseConv, TestDepthwiseConv2, TestDepthwiseConv3, TestDepthwiseConvWithDilation, TestDepthwiseConvWithDilation2 +import numpy as np class TestNGRAPHDepthwiseConv(TestDepthwiseConv): def init_test_case(self): super(TestNGRAPHDepthwiseConv, self).init_test_case() self.use_cuda = False + self.dtype = np.float32 class TestNGRAPHDepthwiseConv2(TestDepthwiseConv2): def init_test_case(self): super(TestNGRAPHDepthwiseConv2, self).init_test_case() self.use_cuda = False + self.dtype = np.float32 class TestNGRAPHDepthwiseConv3(TestDepthwiseConv3): def init_test_case(self): super(TestNGRAPHDepthwiseConv3, self).init_test_case() self.use_cuda = False + self.dtype = np.float32 class TestNGRAPHDepthwiseConvWithDilation(TestDepthwiseConvWithDilation): def init_test_case(self): super(TestNGRAPHDepthwiseConvWithDilation, self).init_test_case() self.use_cuda = False + self.dtype = np.float32 class TestNGRAPHDepthwiseConvWithDilation2(TestDepthwiseConvWithDilation2): def init_test_case(self): super(TestNGRAPHDepthwiseConvWithDilation2, self).init_test_case() self.use_cuda = False + self.dtype = np.float32 del TestDepthwiseConv, TestDepthwiseConv2, TestDepthwiseConv3, TestDepthwiseConvWithDilation, TestDepthwiseConvWithDilation2