diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass_tester.cc index 796aa4039c9e851fc31bde77446467446adb9268..d578ada0db00fed85f7b4f25f1483169c72c2c0b 100644 --- a/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass_tester.cc @@ -15,8 +15,9 @@ #include "paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.h" #include -#include #include + +#include #include "paddle/fluid/framework/ir/pass_tester_helper.h" #include "paddle/fluid/framework/op_registry.h" diff --git a/paddle/fluid/operators/graph_send_recv_op.cc b/paddle/fluid/operators/graph_send_recv_op.cc index b759345eda5657f4803d8208d880c9b3977ea937..f7c006dbcb1a9a23ec619c8d790df8a093530eee 100644 --- a/paddle/fluid/operators/graph_send_recv_op.cc +++ b/paddle/fluid/operators/graph_send_recv_op.cc @@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/ternary.h" namespace paddle { namespace operators { @@ -21,59 +24,6 @@ class GraphSendRecvOP : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "GraphSendRecv"); - OP_INOUT_CHECK(ctx->HasInput("Src_index"), "Input", "Src_index", - "GraphSendRecv"); - OP_INOUT_CHECK(ctx->HasInput("Dst_index"), "Input", "Dst_index", - "GraphSendRecv"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "GraphSendRecv"); - - auto src_index_dims = ctx->GetInputDim("Src_index"); - if (src_index_dims.size() == 2) { - PADDLE_ENFORCE_EQ(src_index_dims[1], 1, - platform::errors::InvalidArgument( - "The last dim of Src_index should be 1 when it " - "is 2D, but we get %d", - src_index_dims[1])); - } else { - PADDLE_ENFORCE_EQ( - src_index_dims.size(), 1, - platform::errors::InvalidArgument( - "The Src_index should be 1D, when it is not 2D, but we get %d", - src_index_dims.size())); - } - - auto dst_index_dims = ctx->GetInputDim("Dst_index"); - if (dst_index_dims.size() == 2) { - PADDLE_ENFORCE_EQ(dst_index_dims[1], 1, - platform::errors::InvalidArgument( - "The last dim of Dst_index should be 1 when it " - "is 2D, but we get %d", - dst_index_dims[1])); - } else { - PADDLE_ENFORCE_EQ( - dst_index_dims.size(), 1, - platform::errors::InvalidArgument("The Dst_index should be 1D, " - "when it is not 2D, but we get %d", - dst_index_dims.size())); - } - - PADDLE_ENFORCE_EQ( - src_index_dims[0], dst_index_dims[0], - platform::errors::InvalidArgument( - "Src_index and Dst_index should have the same shape.")); - - auto dims = ctx->GetInputDim("X"); - ctx->SetOutputDim("Out", dims); - - if (ctx->Attrs().Get("pool_type") == "MEAN") { - OP_INOUT_CHECK(ctx->HasOutput("Dst_count"), "Output", "Dst_count", - "GraphSendRecv"); - ctx->SetOutputDim("Dst_count", {dims[0]}); - } - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -164,10 +114,12 @@ class GraphSendRecvGradOpMaker : public framework::SingleGradOpMaker { } // namespace paddle namespace ops = paddle::operators; -using CPU = paddle::platform::CPUDeviceContext; +DECLARE_INFER_SHAPE_FUNCTOR(graph_send_recv, GraphSendRecvInferShapeFunctor, + PD_INFER_META(phi::GraphSendRecvInferMeta)); REGISTER_OPERATOR(graph_send_recv, ops::GraphSendRecvOP, ops::GraphSendRecvOpMaker, ops::GraphSendRecvGradOpMaker, - ops::GraphSendRecvGradOpMaker); + ops::GraphSendRecvGradOpMaker, + GraphSendRecvInferShapeFunctor); REGISTER_OPERATOR(graph_send_recv_grad, ops::GraphSendRecvGradOp); diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index 67a82392411bd9393f6f8e0129f846e06910b21c..8baf3d7ed965b624585220d61db241cce13664ae 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -285,4 +285,58 @@ void LinspaceInferMeta(const MetaTensor& start, out->set_dtype(start.dtype()); } +void GraphSendRecvInferMeta(const MetaTensor& x, + const MetaTensor& src_index, + const MetaTensor& dst_index, + const std::string& pool_type, + MetaTensor* out, + MetaTensor* dst_count) { + auto src_index_dims = src_index.dims(); + if (src_index_dims.size() == 2) { + PADDLE_ENFORCE_EQ(src_index_dims[1], + 1, + phi::errors::InvalidArgument( + "The last dim of Src_index should be 1 when it " + "is 2D, but we get %d", + src_index_dims[1])); + } else { + PADDLE_ENFORCE_EQ( + src_index_dims.size(), + 1, + phi::errors::InvalidArgument( + "The Src_index should be 1D, when it is not 2D, but we get %d", + src_index_dims.size())); + } + + auto dst_index_dims = dst_index.dims(); + if (dst_index_dims.size() == 2) { + PADDLE_ENFORCE_EQ(dst_index_dims[1], + 1, + phi::errors::InvalidArgument( + "The last dim of Dst_index should be 1 when it " + "is 2D, but we get %d", + dst_index_dims[1])); + } else { + PADDLE_ENFORCE_EQ( + dst_index_dims.size(), + 1, + phi::errors::InvalidArgument("The Dst_index should be 1D, " + "when it is not 2D, but we get %d", + dst_index_dims.size())); + } + + PADDLE_ENFORCE_EQ(src_index_dims[0], + dst_index_dims[0], + phi::errors::InvalidArgument( + "Src_index and Dst_index should have the same shape.")); + + auto dims = x.dims(); + out->set_dims(dims); + out->set_dtype(x.dtype()); + + if (pool_type == "MEAN") { + dst_count->set_dims({dims[0]}); + dst_count->set_dtype(DataType::INT32); + } +} } // namespace phi diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index da48641dee7d5b3e4aa91dc2858d247382e97b98..b54460bc9f68dd2a9821a135c27c50135cff85a2 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -71,4 +71,10 @@ void LinspaceInferMeta(const MetaTensor& start, const MetaTensor& number, MetaTensor* out); +void GraphSendRecvInferMeta(const MetaTensor& x, + const MetaTensor& src_index, + const MetaTensor& dst_index, + const std::string& pool_type, + MetaTensor* out, + MetaTensor* dst_count); } // namespace phi