未验证 提交 1c9b2483 编写于 作者: C Chen Weihang 提交者: GitHub

move trace infer shape (#39517)

上级 5fb9cf60
......@@ -12,8 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/pten/core/infermeta_utils.h"
#include "paddle/pten/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -21,57 +24,6 @@ namespace operators {
class TraceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("Input"), true,
platform::errors::NotFound("Input of TraceOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound("Output of TraceOp is not found."));
int dim1 = ctx->Attrs().Get<int>("axis1");
int dim2 = ctx->Attrs().Get<int>("axis2");
auto x_dims = ctx->GetInputDim("Input");
int dim1_ = dim1 < 0 ? x_dims.size() + dim1 : dim1;
int dim2_ = dim2 < 0 ? x_dims.size() + dim2 : dim2;
PADDLE_ENFORCE_GE(
x_dims.size(), 2,
platform::errors::OutOfRange(
"Input's dim is out of range (expected at least 2, but got %ld).",
x_dims.size()));
PADDLE_ENFORCE_LT(
dim1_, x_dims.size(),
platform::errors::OutOfRange(
"Attr(dim1) is out of range (expected to be in range of [%ld, "
"%ld], but got %ld).",
-(x_dims.size()), (x_dims.size() - 1), dim1));
PADDLE_ENFORCE_LT(
dim2_, x_dims.size(),
platform::errors::OutOfRange(
"Attr(dim2) is out of range (expected to be in range of [%ld, "
"%ld], but got %ld).",
-(x_dims.size()), (x_dims.size() - 1), dim2));
PADDLE_ENFORCE_NE(dim1_, dim2_,
platform::errors::InvalidArgument(
"The dimensions should not be identical "
"%ld vs %ld.",
dim1, dim2));
auto sizes = vectorize(x_dims);
if (x_dims.size() == 2) {
sizes.clear();
sizes.push_back(1);
} else {
sizes.erase(sizes.begin() + std::max(dim1_, dim2_));
sizes.erase(sizes.begin() + std::min(dim1_, dim2_));
}
ctx->SetOutputDim("Out", framework::make_ddim(sizes));
}
};
class TraceOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -155,9 +107,12 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(TraceGradNoNeedBufferVarsInferer, "Input");
} // namespace paddle
namespace ops = paddle::operators;
DELCARE_INFER_SHAPE_FUNCTOR(trace, TraceInferShapeFunctor,
PT_INFER_META(pten::TraceInferMeta));
REGISTER_OPERATOR(trace, ops::TraceOp, ops::TraceOpMaker,
ops::TraceGradOpMaker<paddle::framework::OpDesc>,
ops::TraceGradOpMaker<paddle::imperative::OpBase>);
ops::TraceGradOpMaker<paddle::imperative::OpBase>,
TraceInferShapeFunctor);
REGISTER_OPERATOR(trace_grad, ops::TraceOpGrad,
ops::TraceGradNoNeedBufferVarsInferer);
......
......@@ -444,8 +444,59 @@ void SplitInferMeta(const MetaTensor& x,
(*out)[i].share_lod(x);
}
}
}
void TraceInferMeta(
const MetaTensor& x, int offset, int axis1, int axis2, MetaTensor* out) {
int dim1 = axis1;
int dim2 = axis2;
auto x_dims = x.dims();
int dim1_ = dim1 < 0 ? x_dims.size() + dim1 : dim1;
int dim2_ = dim2 < 0 ? x_dims.size() + dim2 : dim2;
return;
PADDLE_ENFORCE_GE(
x_dims.size(),
2,
pten::errors::OutOfRange(
"Input's dim is out of range (expected at least 2, but got %ld).",
x_dims.size()));
PADDLE_ENFORCE_LT(
dim1_,
x_dims.size(),
pten::errors::OutOfRange(
"Attr(dim1) is out of range (expected to be in range of [%ld, "
"%ld], but got %ld).",
-(x_dims.size()),
(x_dims.size() - 1),
dim1));
PADDLE_ENFORCE_LT(
dim2_,
x_dims.size(),
pten::errors::OutOfRange(
"Attr(dim2) is out of range (expected to be in range of [%ld, "
"%ld], but got %ld).",
-(x_dims.size()),
(x_dims.size() - 1),
dim2));
PADDLE_ENFORCE_NE(
dim1_,
dim2_,
pten::errors::InvalidArgument("The dimensions should not be identical "
"%ld vs %ld.",
dim1,
dim2));
auto sizes = vectorize(x_dims);
if (x_dims.size() == 2) {
sizes.clear();
sizes.push_back(1);
} else {
sizes.erase(sizes.begin() + std::max(dim1_, dim2_));
sizes.erase(sizes.begin() + std::min(dim1_, dim2_));
}
out->set_dims(framework::make_ddim(sizes));
}
} // namespace pten
......@@ -80,4 +80,8 @@ void SplitInferMeta(const MetaTensor& x_meta,
const Scalar& axis,
std::vector<MetaTensor>* out,
MetaConfig config = MetaConfig());
void TraceInferMeta(
const MetaTensor& x, int offset, int axis1, int axis2, MetaTensor* out);
} // namespace pten
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册