提交 6a5545aa 编写于 作者: P phlrain

fix squeeze shape check; test=develop

上级 190cfd69
......@@ -40,7 +40,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
"tensor's rank.");
}
auto out_dims = GetOutputShape(axes, x_dims, ctx);
auto out_dims = GetOutputShape(axes, x_dims, false);
ctx->SetOutputDim("Out", out_dims);
if (x_dims[0] == out_dims[0]) {
// Only pass LoD when the first dimension of output and Input(X)
......@@ -51,7 +51,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
static framework::DDim GetOutputShape(const std::vector<int> squeeze_dims,
const framework::DDim &in_dims,
framework::InferShapeContext *ctx) {
bool is_runtime) {
size_t num_squeeze_dims = squeeze_dims.size();
int cnt_squeezed_dims = 0;
bool should_squeeze[9] = {false};
......@@ -73,7 +73,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
PADDLE_ENFORCE(current >= 0,
"Invalid axis, the negative axis is out of range.");
if (ctx->IsRuntime()) {
if (is_runtime) {
PADDLE_ENFORCE(in_dims[current] == 1,
"Invalid axis index, the axis that will be squeezed "
"should be equal to 1.");
......@@ -108,7 +108,7 @@ class SqueezeOp : public framework::OperatorBase {
const platform::Place &place) const override {
auto &axes = Attr<std::vector<int>>("axes");
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
auto out_dims = SqueezeOpInferShape::GetOutputShape(axes, x_dims);
auto out_dims = SqueezeOpInferShape::GetOutputShape(axes, x_dims, true);
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(out_dims);
......@@ -228,7 +228,7 @@ class Squeeze2Op : public framework::OperatorBase {
const platform::Place &place) const override {
auto &axes = Attr<std::vector<int>>("axes");
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
auto out_dims = Squeeze2OpInferShape::GetOutputShape(axes, x_dims);
auto out_dims = Squeeze2OpInferShape::GetOutputShape(axes, x_dims, true);
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(out_dims);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册