未验证 提交 3779e807 编写于 作者: C crystal 提交者: GitHub

move gather_tree infer shape (#40082)

上级 00bbb8c5
...@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -21,20 +24,6 @@ class GatherTreeOp : public framework::OperatorWithKernel { ...@@ -21,20 +24,6 @@ class GatherTreeOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Ids"), "Input", "Ids", "GatherTree");
OP_INOUT_CHECK(ctx->HasInput("Parents"), "Input", "Parents", "GatherTree");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "GatherTree");
auto ids_dims = ctx->GetInputDim("Ids");
auto parents_dims = ctx->GetInputDim("Parents");
PADDLE_ENFORCE_EQ(ids_dims == parents_dims, true,
platform::errors::InvalidArgument(
"The shape of Input(Parents) must be same with the "
"shape of Input(Ids)."));
ctx->SetOutputDim("Out", ids_dims);
}
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
...@@ -72,4 +61,8 @@ selected ids. ...@@ -72,4 +61,8 @@ selected ids.
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(gather_tree, ops::GatherTreeOp, ops::GatherTreeOpMaker); DELCARE_INFER_SHAPE_FUNCTOR(gather_tree, GatherTreeInferShapeFunctor,
PT_INFER_META(phi::GatherTreeMeta));
REGISTER_OPERATOR(gather_tree, ops::GatherTreeOp, ops::GatherTreeOpMaker,
GatherTreeInferShapeFunctor);
...@@ -348,4 +348,17 @@ void BCELossInferMeta(const MetaTensor& input, ...@@ -348,4 +348,17 @@ void BCELossInferMeta(const MetaTensor& input,
out->share_lod(input); out->share_lod(input);
} }
void GatherTreeMeta(const MetaTensor& ids,
const MetaTensor& parents,
MetaTensor* out) {
auto ids_dims = ids.dims();
auto parents_dims = parents.dims();
PADDLE_ENFORCE_EQ(ids_dims == parents_dims,
true,
phi::errors::InvalidArgument(
"The shape of Input(Parents) must be same with the "
"shape of Input(Ids)."));
out->set_dims(ids_dims);
}
} // namespace phi } // namespace phi
...@@ -68,4 +68,8 @@ void BCELossInferMeta(const MetaTensor& input, ...@@ -68,4 +68,8 @@ void BCELossInferMeta(const MetaTensor& input,
const MetaTensor& label, const MetaTensor& label,
MetaTensor* out, MetaTensor* out,
MetaConfig config = MetaConfig()); MetaConfig config = MetaConfig());
void GatherTreeMeta(const MetaTensor& ids,
const MetaTensor& parents,
MetaTensor* out);
} // namespace phi } // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册