Unverified · Commit 5ae85131 · Authored by: wawltor · Committed by: GitHub

[Phi] add the infer shape meta for the graph_send_recv (#40320)

* add the infer shape meta for the graph_send_recv

* move the infershape code to another file
Parent cf9291b9
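For readers skimming the diff below: the refactor keeps the operator's shape checks unchanged. The fluid InferShape removed from graph_send_recv_op.cc reappears as phi::GraphSendRecvInferMeta in ternary.cc/ternary.h and is wired back into operator registration through DECLARE_INFER_SHAPE_FUNCTOR and PD_INFER_META. The standalone sketch that follows restates that shape contract using plain std::vector<int64_t> in place of phi::DDim; the Dims, CheckIndexDims, and InferGraphSendRecvShapes names are illustrative only, not Paddle APIs.

// Minimal sketch of the shape contract enforced by the new InferMeta:
// Src_index/Dst_index must be 1-D, or 2-D with a trailing dim of 1,
// both index tensors must have the same length, Out inherits X's dims,
// and Dst_count is [num_rows] only when pool_type == "MEAN".
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

using Dims = std::vector<int64_t>;

// Validate one index tensor the same way the InferMeta code below does.
void CheckIndexDims(const Dims& index_dims, const std::string& name) {
  if (index_dims.size() == 2) {
    if (index_dims[1] != 1) {
      throw std::invalid_argument("The last dim of " + name +
                                  " should be 1 when it is 2D");
    }
  } else if (index_dims.size() != 1) {
    throw std::invalid_argument(name + " should be 1D when it is not 2D");
  }
}

// Returns {out_dims, dst_count_dims}; dst_count_dims stays empty unless
// pool_type == "MEAN", mirroring the optional Dst_count output.
std::pair<Dims, Dims> InferGraphSendRecvShapes(const Dims& x_dims,
                                               const Dims& src_index_dims,
                                               const Dims& dst_index_dims,
                                               const std::string& pool_type) {
  CheckIndexDims(src_index_dims, "Src_index");
  CheckIndexDims(dst_index_dims, "Dst_index");
  if (src_index_dims[0] != dst_index_dims[0]) {
    throw std::invalid_argument(
        "Src_index and Dst_index should have the same shape.");
  }
  Dims out_dims = x_dims;          // Out keeps X's full shape.
  Dims dst_count_dims;
  if (pool_type == "MEAN") {
    dst_count_dims = {x_dims[0]};  // One count per output row (int32 in the op).
  }
  return {out_dims, dst_count_dims};
}

int main() {
  // X: [10, 16] node features, 20 edges described by Src_index/Dst_index.
  auto shapes = InferGraphSendRecvShapes({10, 16}, {20}, {20}, "MEAN");
  std::cout << "Out: [" << shapes.first[0] << ", " << shapes.first[1] << "]\n"
            << "Dst_count: [" << shapes.second[0] << "]\n";
  return 0;
}

Running the sketch with X of shape [10, 16] and 20 edges prints Out: [10, 16] and Dst_count: [10], which matches what the new InferMeta sets in the last two hunks of this commit.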
@@ -15,8 +15,9 @@
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.h"
#include <gtest/gtest.h>
#include <boost/logic/tribool.hpp>
#include <unordered_set>
#include <boost/logic/tribool.hpp>
#include "paddle/fluid/framework/ir/pass_tester_helper.h" #include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
......
...@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/ternary.h"
namespace paddle {
namespace operators {
@@ -21,59 +24,6 @@ class GraphSendRecvOP : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "GraphSendRecv");
OP_INOUT_CHECK(ctx->HasInput("Src_index"), "Input", "Src_index",
"GraphSendRecv");
OP_INOUT_CHECK(ctx->HasInput("Dst_index"), "Input", "Dst_index",
"GraphSendRecv");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "GraphSendRecv");
auto src_index_dims = ctx->GetInputDim("Src_index");
if (src_index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(src_index_dims[1], 1,
platform::errors::InvalidArgument(
"The last dim of Src_index should be 1 when it "
"is 2D, but we get %d",
src_index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
src_index_dims.size(), 1,
platform::errors::InvalidArgument(
"The Src_index should be 1D, when it is not 2D, but we get %d",
src_index_dims.size()));
}
auto dst_index_dims = ctx->GetInputDim("Dst_index");
if (dst_index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(dst_index_dims[1], 1,
platform::errors::InvalidArgument(
"The last dim of Dst_index should be 1 when it "
"is 2D, but we get %d",
dst_index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
dst_index_dims.size(), 1,
platform::errors::InvalidArgument("The Dst_index should be 1D, "
"when it is not 2D, but we get %d",
dst_index_dims.size()));
}
PADDLE_ENFORCE_EQ(
src_index_dims[0], dst_index_dims[0],
platform::errors::InvalidArgument(
"Src_index and Dst_index should have the same shape."));
auto dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", dims);
if (ctx->Attrs().Get<std::string>("pool_type") == "MEAN") {
OP_INOUT_CHECK(ctx->HasOutput("Dst_count"), "Output", "Dst_count",
"GraphSendRecv");
ctx->SetOutputDim("Dst_count", {dims[0]});
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
@@ -164,10 +114,12 @@ class GraphSendRecvGradOpMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
DECLARE_INFER_SHAPE_FUNCTOR(graph_send_recv, GraphSendRecvInferShapeFunctor,
PD_INFER_META(phi::GraphSendRecvInferMeta));
REGISTER_OPERATOR(graph_send_recv, ops::GraphSendRecvOP,
ops::GraphSendRecvOpMaker,
ops::GraphSendRecvGradOpMaker<paddle::framework::OpDesc>,
ops::GraphSendRecvGradOpMaker<paddle::imperative::OpBase>);
ops::GraphSendRecvGradOpMaker<paddle::imperative::OpBase>,
GraphSendRecvInferShapeFunctor);
REGISTER_OPERATOR(graph_send_recv_grad, ops::GraphSendRecvGradOp);
@@ -285,4 +285,58 @@ void LinspaceInferMeta(const MetaTensor& start,
out->set_dtype(start.dtype());
}
void GraphSendRecvInferMeta(const MetaTensor& x,
const MetaTensor& src_index,
const MetaTensor& dst_index,
const std::string& pool_type,
MetaTensor* out,
MetaTensor* dst_count) {
auto src_index_dims = src_index.dims();
if (src_index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(src_index_dims[1],
1,
phi::errors::InvalidArgument(
"The last dim of Src_index should be 1 when it "
"is 2D, but we get %d",
src_index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
src_index_dims.size(),
1,
phi::errors::InvalidArgument(
"The Src_index should be 1D, when it is not 2D, but we get %d",
src_index_dims.size()));
}
auto dst_index_dims = dst_index.dims();
if (dst_index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(dst_index_dims[1],
1,
phi::errors::InvalidArgument(
"The last dim of Dst_index should be 1 when it "
"is 2D, but we get %d",
dst_index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
dst_index_dims.size(),
1,
phi::errors::InvalidArgument("The Dst_index should be 1D, "
"when it is not 2D, but we get %d",
dst_index_dims.size()));
}
PADDLE_ENFORCE_EQ(src_index_dims[0],
dst_index_dims[0],
phi::errors::InvalidArgument(
"Src_index and Dst_index should have the same shape."));
auto dims = x.dims();
out->set_dims(dims);
out->set_dtype(x.dtype());
if (pool_type == "MEAN") {
dst_count->set_dims({dims[0]});
dst_count->set_dtype(DataType::INT32);
}
}
} // namespace phi
@@ -71,4 +71,10 @@ void LinspaceInferMeta(const MetaTensor& start,
const MetaTensor& number,
MetaTensor* out);
void GraphSendRecvInferMeta(const MetaTensor& x,
const MetaTensor& src_index,
const MetaTensor& dst_index,
const std::string& pool_type,
MetaTensor* out,
MetaTensor* dst_count);
} // namespace phi