未验证 提交 5ae85131 编写于 作者: W wawltor 提交者: GitHub

[Phi] add the infer shape meta for the graph_send_recv (#40320)

* add the infer shape meta for the graph_send_recv

* move the infershape code to another file
上级 cf9291b9
......@@ -15,8 +15,9 @@
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.h"
#include <gtest/gtest.h>
#include <boost/logic/tribool.hpp>
#include <unordered_set>
#include <boost/logic/tribool.hpp>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_registry.h"
......
......@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/ternary.h"
namespace paddle {
namespace operators {
......@@ -21,59 +24,6 @@ class GraphSendRecvOP : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "GraphSendRecv");
OP_INOUT_CHECK(ctx->HasInput("Src_index"), "Input", "Src_index",
"GraphSendRecv");
OP_INOUT_CHECK(ctx->HasInput("Dst_index"), "Input", "Dst_index",
"GraphSendRecv");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "GraphSendRecv");
auto src_index_dims = ctx->GetInputDim("Src_index");
if (src_index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(src_index_dims[1], 1,
platform::errors::InvalidArgument(
"The last dim of Src_index should be 1 when it "
"is 2D, but we get %d",
src_index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
src_index_dims.size(), 1,
platform::errors::InvalidArgument(
"The Src_index should be 1D, when it is not 2D, but we get %d",
src_index_dims.size()));
}
auto dst_index_dims = ctx->GetInputDim("Dst_index");
if (dst_index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(dst_index_dims[1], 1,
platform::errors::InvalidArgument(
"The last dim of Dst_index should be 1 when it "
"is 2D, but we get %d",
dst_index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
dst_index_dims.size(), 1,
platform::errors::InvalidArgument("The Dst_index should be 1D, "
"when it is not 2D, but we get %d",
dst_index_dims.size()));
}
PADDLE_ENFORCE_EQ(
src_index_dims[0], dst_index_dims[0],
platform::errors::InvalidArgument(
"Src_index and Dst_index should have the same shape."));
auto dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", dims);
if (ctx->Attrs().Get<std::string>("pool_type") == "MEAN") {
OP_INOUT_CHECK(ctx->HasOutput("Dst_count"), "Output", "Dst_count",
"GraphSendRecv");
ctx->SetOutputDim("Dst_count", {dims[0]});
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
......@@ -164,10 +114,12 @@ class GraphSendRecvGradOpMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
DECLARE_INFER_SHAPE_FUNCTOR(graph_send_recv, GraphSendRecvInferShapeFunctor,
PD_INFER_META(phi::GraphSendRecvInferMeta));
REGISTER_OPERATOR(graph_send_recv, ops::GraphSendRecvOP,
ops::GraphSendRecvOpMaker,
ops::GraphSendRecvGradOpMaker<paddle::framework::OpDesc>,
ops::GraphSendRecvGradOpMaker<paddle::imperative::OpBase>);
ops::GraphSendRecvGradOpMaker<paddle::imperative::OpBase>,
GraphSendRecvInferShapeFunctor);
REGISTER_OPERATOR(graph_send_recv_grad, ops::GraphSendRecvGradOp);
......@@ -285,4 +285,58 @@ void LinspaceInferMeta(const MetaTensor& start,
out->set_dtype(start.dtype());
}
void GraphSendRecvInferMeta(const MetaTensor& x,
const MetaTensor& src_index,
const MetaTensor& dst_index,
const std::string& pool_type,
MetaTensor* out,
MetaTensor* dst_count) {
auto src_index_dims = src_index.dims();
if (src_index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(src_index_dims[1],
1,
phi::errors::InvalidArgument(
"The last dim of Src_index should be 1 when it "
"is 2D, but we get %d",
src_index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
src_index_dims.size(),
1,
phi::errors::InvalidArgument(
"The Src_index should be 1D, when it is not 2D, but we get %d",
src_index_dims.size()));
}
auto dst_index_dims = dst_index.dims();
if (dst_index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(dst_index_dims[1],
1,
phi::errors::InvalidArgument(
"The last dim of Dst_index should be 1 when it "
"is 2D, but we get %d",
dst_index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
dst_index_dims.size(),
1,
phi::errors::InvalidArgument("The Dst_index should be 1D, "
"when it is not 2D, but we get %d",
dst_index_dims.size()));
}
PADDLE_ENFORCE_EQ(src_index_dims[0],
dst_index_dims[0],
phi::errors::InvalidArgument(
"Src_index and Dst_index should have the same shape."));
auto dims = x.dims();
out->set_dims(dims);
out->set_dtype(x.dtype());
if (pool_type == "MEAN") {
dst_count->set_dims({dims[0]});
dst_count->set_dtype(DataType::INT32);
}
}
} // namespace phi
......@@ -71,4 +71,10 @@ void LinspaceInferMeta(const MetaTensor& start,
const MetaTensor& number,
MetaTensor* out);
void GraphSendRecvInferMeta(const MetaTensor& x,
const MetaTensor& src_index,
const MetaTensor& dst_index,
const std::string& pool_type,
MetaTensor* out,
MetaTensor* dst_count);
} // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册