diff --git a/paddle/fluid/operators/trace_op.cc b/paddle/fluid/operators/trace_op.cc index aabad64c894df516ceeb8c3f7f753f3aa4fc70d3..6145db5f5ef63ef7af22a5a3746b3da49f8ffb95 100644 --- a/paddle/fluid/operators/trace_op.cc +++ b/paddle/fluid/operators/trace_op.cc @@ -12,8 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/pten/core/infermeta_utils.h" +#include "paddle/pten/infermeta/unary.h" namespace paddle { namespace operators { @@ -21,57 +24,6 @@ namespace operators { class TraceOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE_EQ( - ctx->HasInput("Input"), true, - platform::errors::NotFound("Input of TraceOp is not found.")); - - PADDLE_ENFORCE_EQ( - ctx->HasOutput("Out"), true, - platform::errors::NotFound("Output of TraceOp is not found.")); - - int dim1 = ctx->Attrs().Get("axis1"); - int dim2 = ctx->Attrs().Get("axis2"); - - auto x_dims = ctx->GetInputDim("Input"); - - int dim1_ = dim1 < 0 ? x_dims.size() + dim1 : dim1; - int dim2_ = dim2 < 0 ? x_dims.size() + dim2 : dim2; - - PADDLE_ENFORCE_GE( - x_dims.size(), 2, - platform::errors::OutOfRange( - "Input's dim is out of range (expected at least 2, but got %ld).", - x_dims.size())); - PADDLE_ENFORCE_LT( - dim1_, x_dims.size(), - platform::errors::OutOfRange( - "Attr(dim1) is out of range (expected to be in range of [%ld, " - "%ld], but got %ld).", - -(x_dims.size()), (x_dims.size() - 1), dim1)); - PADDLE_ENFORCE_LT( - dim2_, x_dims.size(), - platform::errors::OutOfRange( - "Attr(dim2) is out of range (expected to be in range of [%ld, " - "%ld], but got %ld).", - -(x_dims.size()), (x_dims.size() - 1), dim2)); - PADDLE_ENFORCE_NE(dim1_, dim2_, - platform::errors::InvalidArgument( - "The dimensions should not be identical " - "%ld vs %ld.", - dim1, dim2)); - - auto sizes = vectorize(x_dims); - if (x_dims.size() == 2) { - sizes.clear(); - sizes.push_back(1); - } else { - sizes.erase(sizes.begin() + std::max(dim1_, dim2_)); - sizes.erase(sizes.begin() + std::min(dim1_, dim2_)); - } - ctx->SetOutputDim("Out", framework::make_ddim(sizes)); - } }; class TraceOpMaker : public framework::OpProtoAndCheckerMaker { @@ -155,9 +107,12 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(TraceGradNoNeedBufferVarsInferer, "Input"); } // namespace paddle namespace ops = paddle::operators; +DELCARE_INFER_SHAPE_FUNCTOR(trace, TraceInferShapeFunctor, + PT_INFER_META(pten::TraceInferMeta)); REGISTER_OPERATOR(trace, ops::TraceOp, ops::TraceOpMaker, ops::TraceGradOpMaker, - ops::TraceGradOpMaker); + ops::TraceGradOpMaker, + TraceInferShapeFunctor); REGISTER_OPERATOR(trace_grad, ops::TraceOpGrad, ops::TraceGradNoNeedBufferVarsInferer); diff --git a/paddle/pten/infermeta/unary.cc b/paddle/pten/infermeta/unary.cc index ca59937399a226558c213fed5b43a2311a2f368a..ec9ba519b95ba740f7370a8547b092ed4c9acb4f 100644 --- a/paddle/pten/infermeta/unary.cc +++ b/paddle/pten/infermeta/unary.cc @@ -444,8 +444,59 @@ void SplitInferMeta(const MetaTensor& x, (*out)[i].share_lod(x); } } +} + +void TraceInferMeta( + const MetaTensor& x, int offset, int axis1, int axis2, MetaTensor* out) { + int dim1 = axis1; + int dim2 = axis2; + + auto x_dims = x.dims(); - return; + int dim1_ = dim1 < 0 ? x_dims.size() + dim1 : dim1; + int dim2_ = dim2 < 0 ? x_dims.size() + dim2 : dim2; + + PADDLE_ENFORCE_GE( + x_dims.size(), + 2, + pten::errors::OutOfRange( + "Input's dim is out of range (expected at least 2, but got %ld).", + x_dims.size())); + PADDLE_ENFORCE_LT( + dim1_, + x_dims.size(), + pten::errors::OutOfRange( + "Attr(dim1) is out of range (expected to be in range of [%ld, " + "%ld], but got %ld).", + -(x_dims.size()), + (x_dims.size() - 1), + dim1)); + PADDLE_ENFORCE_LT( + dim2_, + x_dims.size(), + pten::errors::OutOfRange( + "Attr(dim2) is out of range (expected to be in range of [%ld, " + "%ld], but got %ld).", + -(x_dims.size()), + (x_dims.size() - 1), + dim2)); + PADDLE_ENFORCE_NE( + dim1_, + dim2_, + pten::errors::InvalidArgument("The dimensions should not be identical " + "%ld vs %ld.", + dim1, + dim2)); + + auto sizes = vectorize(x_dims); + if (x_dims.size() == 2) { + sizes.clear(); + sizes.push_back(1); + } else { + sizes.erase(sizes.begin() + std::max(dim1_, dim2_)); + sizes.erase(sizes.begin() + std::min(dim1_, dim2_)); + } + out->set_dims(framework::make_ddim(sizes)); } } // namespace pten diff --git a/paddle/pten/infermeta/unary.h b/paddle/pten/infermeta/unary.h index 4c816c4adbc233e0442c2100f62ee8e62cc8f78c..5bdf1d491c6342e302bb1f0ea39f18c65749b8d6 100644 --- a/paddle/pten/infermeta/unary.h +++ b/paddle/pten/infermeta/unary.h @@ -80,4 +80,8 @@ void SplitInferMeta(const MetaTensor& x_meta, const Scalar& axis, std::vector* out, MetaConfig config = MetaConfig()); + +void TraceInferMeta( + const MetaTensor& x, int offset, int axis1, int axis2, MetaTensor* out); + } // namespace pten