未验证 提交 3779e807 编写于 作者: C crystal 提交者: GitHub

move gather_tree infer shape (#40082)

上级 00bbb8c5
......@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
......@@ -21,20 +24,6 @@ class GatherTreeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Ids"), "Input", "Ids", "GatherTree");
OP_INOUT_CHECK(ctx->HasInput("Parents"), "Input", "Parents", "GatherTree");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "GatherTree");
auto ids_dims = ctx->GetInputDim("Ids");
auto parents_dims = ctx->GetInputDim("Parents");
PADDLE_ENFORCE_EQ(ids_dims == parents_dims, true,
platform::errors::InvalidArgument(
"The shape of Input(Parents) must be same with the "
"shape of Input(Ids)."));
ctx->SetOutputDim("Out", ids_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
......@@ -72,4 +61,8 @@ selected ids.
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(gather_tree, ops::GatherTreeOp, ops::GatherTreeOpMaker);
DELCARE_INFER_SHAPE_FUNCTOR(gather_tree, GatherTreeInferShapeFunctor,
PT_INFER_META(phi::GatherTreeMeta));
REGISTER_OPERATOR(gather_tree, ops::GatherTreeOp, ops::GatherTreeOpMaker,
GatherTreeInferShapeFunctor);
......@@ -348,4 +348,17 @@ void BCELossInferMeta(const MetaTensor& input,
out->share_lod(input);
}
void GatherTreeMeta(const MetaTensor& ids,
const MetaTensor& parents,
MetaTensor* out) {
auto ids_dims = ids.dims();
auto parents_dims = parents.dims();
PADDLE_ENFORCE_EQ(ids_dims == parents_dims,
true,
phi::errors::InvalidArgument(
"The shape of Input(Parents) must be same with the "
"shape of Input(Ids)."));
out->set_dims(ids_dims);
}
} // namespace phi
......@@ -68,4 +68,8 @@ void BCELossInferMeta(const MetaTensor& input,
const MetaTensor& label,
MetaTensor* out,
MetaConfig config = MetaConfig());
void GatherTreeMeta(const MetaTensor& ids,
const MetaTensor& parents,
MetaTensor* out);
} // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册