diff --git a/paddle/fluid/operators/unsqueeze_op.cc b/paddle/fluid/operators/unsqueeze_op.cc index 80c0d1fefd39034b64e1bd425da95ae39e38dedc..fc849e73c579f3457852e05dec404c001b74b19e 100644 --- a/paddle/fluid/operators/unsqueeze_op.cc +++ b/paddle/fluid/operators/unsqueeze_op.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/operators/unsqueeze_op.h" +#include #include #include #include "paddle/fluid/framework/op_registry.h" @@ -19,20 +21,22 @@ limitations under the License. */ namespace paddle { namespace operators { -class UnsqueezeOpInferShape : public framework::InferShapeBase { +class UnsqueezeOp : public framework::OperatorWithKernel { public: - void operator()(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of Unsqueeze operator should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of Unsqueeze operator should not be null."); + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, + "Input(X) of Unsqueeze operator should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, + "Output(Out) of Unsqueeze operator should not be null."); const auto &axes = ctx->Attrs().Get>("axes"); const auto &x_dims = ctx->GetInputDim("X"); // Validity Check: input tensor dims (<6). - PADDLE_ENFORCE(x_dims.size() <= 6, - "Invalid dimensions, the rank of Input(X) " - "should be in the range of [1, 6] (Eigen limit)"); + PADDLE_ENFORCE_LE(x_dims.size(), 6, + "Invalid dimensions, the rank of Input(X) " + "should be in the range of [1, 6] (Eigen limit)"); auto out_dims = GetOutputShape(axes, x_dims); ctx->SetOutputDim("Out", out_dims); if (x_dims[0] == out_dims[0]) { @@ -49,15 +53,14 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase { std::vector output_shape(output_size, 0); // Validity Check: rank range. - PADDLE_ENFORCE(output_size <= 6, - "The output tensor's rank should be less than 6."); + PADDLE_ENFORCE_LE(output_size, 6, + "The output tensor's rank should be less than 6."); for (int axis : unsqz_dims) { int cur = axis < 0 ? axis + cur_output_size + 1 : axis; // Vaildity Check: the axis bound - PADDLE_ENFORCE( - cur >= 0 && cur <= cur_output_size, - "The unsqueeze dims must be within range of current rank."); + PADDLE_ENFORCE_GE(cur, 0); + PADDLE_ENFORCE_LE(cur, cur_output_size); // Move old axis, and insert new axis for (int i = cur_output_size; i >= cur; --i) { if (output_shape[i] == 1) { @@ -82,27 +85,6 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase { } }; -class UnsqueezeOp : public framework::OperatorBase { - public: - using OperatorBase::OperatorBase; - - private: - void RunImpl(const framework::Scope &scope, - const platform::Place &place) const override { - auto &axes = Attr>("axes"); - auto x_dims = scope.FindVar(Input("X"))->Get().dims(); - auto out_dims = UnsqueezeOpInferShape::GetOutputShape(axes, x_dims); - - framework::AttributeMap attrs; - attrs["shape"] = framework::vectorize2int(out_dims); - // Invoke Reshape op. - auto reshape_op = framework::OpRegistry::CreateOp( - "reshape", {{"X", {Input("X")}}, {"Shape", {}}}, - {{"Out", {Output("Out")}}}, attrs); - reshape_op->Run(scope, place); - } -}; - class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -112,17 +94,17 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker { "(std::vector). List of integers," " indicating the dimensions to be inserted") .AddCustomChecker([](const std::vector &axes) { - PADDLE_ENFORCE(!axes.empty(), - "Invalid axes, The unsqueeze axes is empty."); + PADDLE_ENFORCE_EQ(!axes.empty(), true, + "Invalid axes, The unsqueeze axes is empty."); // Validity Check: axes dims (<6). - PADDLE_ENFORCE(static_cast(axes.size()) < 6, - "Invalid dimensions, dynamic dimensions should be " - "within [1, 6] dimensions (Eigen limit)."); + PADDLE_ENFORCE_LT(static_cast(axes.size()), 6, + "Invalid dimensions, dynamic dimensions should be " + "within [1, 6] dimensions (Eigen limit)."); // Validity Check: the range of unsqueeze aixs. for (int axis : axes) { - PADDLE_ENFORCE(axis < 6, - "Invalid dimensions, input axis should be" - " within [1, 6] dimensions (Eigen limit)."); + PADDLE_ENFORCE_LT(axis, 6, + "Invalid dimensions, input axis should be" + " within [1, 6] dimensions (Eigen limit)."); } }); AddComment(R"DOC( @@ -139,47 +121,47 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker { } }; -class UnsqueezeGradInferShape : public framework::InferShapeBase { +class UnsqueezeGradOp : public framework::OperatorWithKernel { public: - void operator()(framework::InferShapeContext *ctx) const override { + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->ShareLoD("X", framework::GradVarName("X")); } }; -class UnsqueezeGradOp : public framework::OperatorBase { - public: - using OperatorBase::OperatorBase; - - private: - void RunImpl(const framework::Scope &scope, - const platform::Place &place) const override { - auto dx_name = Output(framework::GradVarName("X")); - auto dout_name = Input(framework::GradVarName("Out")); - auto x_dims = scope.FindVar(Input("X"))->Get().dims(); - - framework::AttributeMap attrs; - attrs["shape"] = framework::vectorize2int(x_dims); - - auto reshape_op = framework::OpRegistry::CreateOp( - "reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}}, - attrs); - reshape_op->Run(scope, place); - } -}; - // FIXME(zcd): unsqueeze2 adds an intermediate output(XShape) based on // unsqueeze, the XShape is used to carry the shape and lod of X which // will be used in unsqueeze_grad, in this way, the framework can reuse // the memory of X immediately the unsqueeze2_op is finished. // Considering compatibility issues, we could not fix unsqueeze2_op -class Unsqueeze2OpInferShape : public UnsqueezeOpInferShape { +class Unsqueeze2Op : public framework::OperatorWithKernel { public: - void operator()(framework::InferShapeContext *ctx) const override { - UnsqueezeOpInferShape::operator()(ctx); - PADDLE_ENFORCE(ctx->HasOutput("XShape"), - "Output(XShape) of Unsqueeze operator should not be null."); + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, + "Input(X) of Unsqueeze operator should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, + "Output(Out) of Unsqueeze operator should not be null."); + + const auto &axes = ctx->Attrs().Get>("axes"); const auto &x_dims = ctx->GetInputDim("X"); + // Validity Check: input tensor dims (<6). + PADDLE_ENFORCE_LE(x_dims.size(), 6, + "Invalid dimensions, the rank of Input(X) " + "should be in the range of [1, 6] (Eigen limit)"); + auto out_dims = UnsqueezeOp::GetOutputShape(axes, x_dims); + ctx->SetOutputDim("Out", out_dims); + if (x_dims[0] == out_dims[0]) { + // Only pass LoD when the first dimension of output and Input(X) + // are the same. + ctx->ShareLoD("X", "Out"); + } + + PADDLE_ENFORCE_EQ( + ctx->HasOutput("XShape"), true, + "Output(XShape) of Unsqueeze operator should not be null."); std::vector xshape_dims(x_dims.size() + 1); xshape_dims[0] = 0; for (int i = 0; i < x_dims.size(); ++i) { @@ -201,27 +183,6 @@ class Unsqueeze2OpMaker : public UnsqueezeOpMaker { } }; -class Unsqueeze2Op : public framework::OperatorBase { - public: - using OperatorBase::OperatorBase; - - private: - void RunImpl(const framework::Scope &scope, - const platform::Place &place) const override { - auto &axes = Attr>("axes"); - auto x_dims = scope.FindVar(Input("X"))->Get().dims(); - auto out_dims = Unsqueeze2OpInferShape::GetOutputShape(axes, x_dims); - - framework::AttributeMap attrs; - attrs["shape"] = framework::vectorize2int(out_dims); - // Invoke Reshape op. - auto reshape_op = framework::OpRegistry::CreateOp( - "reshape2", {{"X", {Input("X")}}, {"Shape", {}}}, - {{"Out", {Output("Out")}}, {"XShape", {Output("XShape")}}}, attrs); - reshape_op->Run(scope, place); - } -}; - class Unsqueeze2GradOpMaker : public framework::SingleGradOpDescMaker { public: using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; @@ -237,43 +198,26 @@ class Unsqueeze2GradOpMaker : public framework::SingleGradOpDescMaker { } }; -class Unsqueeze2GradInferShape : public framework::InferShapeBase { +class Unsqueeze2GradOp : public framework::OperatorWithKernel { public: - void operator()(framework::InferShapeContext *context) const override { - PADDLE_ENFORCE(context->HasInput("XShape"), - "Input(XShape) shouldn't be null."); - PADDLE_ENFORCE(context->HasInput(framework::GradVarName("Out")), - "Input(Out@GRAD) shouldn't be null."); + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext *context) const override { + PADDLE_ENFORCE_EQ(context->HasInput("XShape"), true, + "Input(XShape) shouldn't be null."); + PADDLE_ENFORCE_EQ(context->HasInput(framework::GradVarName("Out")), true, + "Input(Out@GRAD) shouldn't be null."); auto xshape_dims = context->GetInputDim("XShape"); auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size()); context->SetOutputDim(framework::GradVarName("X"), x_dims); context->ShareLoD("XShape", framework::GradVarName("X")); } -}; - -class Unsqueeze2GradOp : public framework::OperatorBase { - public: - using OperatorBase::OperatorBase; - - private: - void RunImpl(const framework::Scope &scope, - const platform::Place &place) const override { - auto dx_name = Output(framework::GradVarName("X")); - auto dout_name = Input(framework::GradVarName("Out")); - auto xshape_name = Input("XShape"); - auto xshape_dims = - scope.FindVar(xshape_name)->Get().dims(); - auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size()); - - framework::AttributeMap attrs; - attrs["shape"] = framework::vectorize2int(x_dims); - auto reshape_op = framework::OpRegistry::CreateOp( - "reshape2_grad", {{framework::GradVarName("Out"), {dout_name}}, - {"Shape", {}}, - {"XShape", {xshape_name}}}, - {{framework::GradVarName("X"), {dx_name}}}, attrs); - reshape_op->Run(scope, place); + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + ctx.Input(framework::GradVarName("Out"))->type(), + ctx.device_context()); } }; @@ -281,23 +225,43 @@ DECLARE_INPLACE_OP_INFERER(UnsqueezeInplaceInferer, {"X", "Out"}); DECLARE_INPLACE_OP_INFERER(UnsqueezeGradInplaceInferer, {framework::GradVarName("Out"), framework::GradVarName("X")}); - } // namespace operators } // namespace paddle -// Tell linker to use reshape op. -USE_OP(reshape); - namespace ops = paddle::operators; REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker, - ops::UnsqueezeOpInferShape, paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp, - ops::UnsqueezeGradInferShape); +REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp); REGISTER_OPERATOR(unsqueeze2, ops::Unsqueeze2Op, ops::Unsqueeze2OpMaker, - ops::Unsqueeze2OpInferShape, ops::Unsqueeze2GradOpMaker, - ops::UnsqueezeInplaceInferer); + ops::Unsqueeze2GradOpMaker, ops::UnsqueezeInplaceInferer); REGISTER_OPERATOR(unsqueeze2_grad, ops::Unsqueeze2GradOp, - ops::Unsqueeze2GradInferShape, ops::UnsqueezeGradInplaceInferer); + +REGISTER_OP_CPU_KERNEL( + unsqueeze, ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel); +REGISTER_OP_CPU_KERNEL( + unsqueeze_grad, + ops::UnsqueezeGradKernel, + ops::UnsqueezeGradKernel, + ops::UnsqueezeGradKernel, + ops::UnsqueezeGradKernel, + ops::UnsqueezeGradKernel); +REGISTER_OP_CPU_KERNEL( + unsqueeze2, + ops::Unsqueeze2Kernel, + ops::Unsqueeze2Kernel, + ops::Unsqueeze2Kernel, + ops::Unsqueeze2Kernel, + ops::Unsqueeze2Kernel); +REGISTER_OP_CPU_KERNEL( + unsqueeze2_grad, + ops::Unsqueeze2GradKernel, + ops::Unsqueeze2GradKernel, + ops::Unsqueeze2GradKernel, + ops::Unsqueeze2GradKernel, + ops::Unsqueeze2GradKernel); diff --git a/paddle/fluid/operators/unsqueeze_op.cu.cc b/paddle/fluid/operators/unsqueeze_op.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..fbdec5af94a570f430f9c50a16fe01b69a4f2d14 --- /dev/null +++ b/paddle/fluid/operators/unsqueeze_op.cu.cc @@ -0,0 +1,45 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/unsqueeze_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + unsqueeze, ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel); +REGISTER_OP_CUDA_KERNEL( + unsqueeze_grad, + ops::UnsqueezeGradKernel, + ops::UnsqueezeGradKernel, + ops::UnsqueezeGradKernel, + ops::UnsqueezeGradKernel, + ops::UnsqueezeGradKernel); +REGISTER_OP_CUDA_KERNEL( + unsqueeze2, + ops::Unsqueeze2Kernel, + ops::Unsqueeze2Kernel, + ops::Unsqueeze2Kernel, + ops::Unsqueeze2Kernel, + ops::Unsqueeze2Kernel); +REGISTER_OP_CUDA_KERNEL( + unsqueeze2_grad, + ops::Unsqueeze2GradKernel, + ops::Unsqueeze2GradKernel, + ops::Unsqueeze2GradKernel, + ops::Unsqueeze2GradKernel, + ops::Unsqueeze2GradKernel); diff --git a/paddle/fluid/operators/unsqueeze_op.h b/paddle/fluid/operators/unsqueeze_op.h new file mode 100644 index 0000000000000000000000000000000000000000..68f0cbe81223126c3f850a6e738c7b581910c69d --- /dev/null +++ b/paddle/fluid/operators/unsqueeze_op.h @@ -0,0 +1,137 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/pooling.h" +#include "paddle/fluid/platform/device_context.h" + +namespace paddle { +namespace operators { + +template +class UnsqueezeKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto &axes = context.Attr>("axes"); + auto *in = context.Input("X"); + auto *out = context.Output("Out"); + auto x_dims = in->dims(); + auto out_dims = GetOutputShape(axes, x_dims); + + out->mutable_data(context.GetPlace(), in->type()); + framework::TensorCopy( + *in, context.GetPlace(), + context.template device_context(), out); + out->Resize(out_dims); + } + + static framework::DDim GetOutputShape(const std::vector unsqz_dims, + const framework::DDim &in_dims) { + int output_size = in_dims.size() + static_cast(unsqz_dims.size()); + int cur_output_size = in_dims.size(); + std::vector output_shape(output_size, 0); + + // Validity Check: rank range. + PADDLE_ENFORCE_LE(output_size, 6, + "The output tensor's rank should be less than 6."); + + for (int axis : unsqz_dims) { + int cur = axis < 0 ? axis + cur_output_size + 1 : axis; + // Vaildity Check: the axis bound + PADDLE_ENFORCE_GE(cur, 0); + PADDLE_ENFORCE_LE(cur, cur_output_size); + // Move old axis, and insert new axis + for (int i = cur_output_size; i >= cur; --i) { + if (output_shape[i] == 1) { + // Move axis + output_shape[i + 1] = 1; + output_shape[i] = 0; + } + } + output_shape[cur] = 1; + // Add the output size. + cur_output_size++; + } + + // Make output shape + for (int in_idx = 0, out_idx = 0; out_idx < output_size; ++out_idx) { + if (output_shape[out_idx] == 0) { + output_shape[out_idx] = in_dims[in_idx++]; + } + } + + return framework::make_ddim(output_shape); + } +}; + +template +class UnsqueezeGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *d_out = + ctx.Input(framework::GradVarName("Out")); + auto *d_x = ctx.Output(framework::GradVarName("X")); + auto in_dims = ctx.Input("X")->dims(); + + d_x->mutable_data(ctx.GetPlace(), d_out->type()); + framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x); + d_x->Resize(in_dims); + } +}; + +template +class Unsqueeze2Kernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto *out = context.Output("Out"); + auto *in = context.Input("X"); + + auto &axes = context.Attr>("axes"); + + auto x_dims = in->dims(); + auto out_dims = + UnsqueezeKernel::GetOutputShape(axes, x_dims); + + out->mutable_data(context.GetPlace(), in->type()); + framework::TensorCopy( + *in, context.GetPlace(), + context.template device_context(), out); + out->Resize(out_dims); + } +}; + +template +class Unsqueeze2GradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *d_out = + ctx.Input(framework::GradVarName("Out")); + auto *d_x = ctx.Output(framework::GradVarName("X")); + // auto in_dims = d_x->dims(); + + auto xshape_dims = ctx.Input("XShape")->dims(); + auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size()); + + d_x->mutable_data(ctx.GetPlace(), d_out->type()); + framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x); + d_x->Resize(x_dims); + } +}; +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_unsqueeze2_op.py b/python/paddle/fluid/tests/unittests/test_unsqueeze2_op.py new file mode 100644 index 0000000000000000000000000000000000000000..14dd2bb06f9a18d0b15a4aee4e9e6bfdf8c41206 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_unsqueeze2_op.py @@ -0,0 +1,83 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np + +from op_test import OpTest + + +# Correct: General. +class TestUnsqueezeOp(OpTest): + def setUp(self): + self.init_test_case() + self.op_type = "unsqueeze2" + self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")} + self.init_attrs() + self.outputs = { + "Out": self.inputs["X"].reshape(self.new_shape), + "XShape": np.random.random(self.ori_shape).astype("float32") + } + + def test_check_output(self): + self.check_output(no_check_set=["XShape"]) + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + def init_test_case(self): + self.ori_shape = (3, 5) + self.axes = (1, 2) + self.new_shape = (3, 1, 1, 5) + + def init_attrs(self): + self.attrs = {"axes": self.axes} + + +# Correct: Single input index. +class TestUnsqueezeOp1(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 5) + self.axes = (-1, ) + self.new_shape = (3, 5, 1) + + +# Correct: Mixed input axis. +class TestUnsqueezeOp2(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 5) + self.axes = (0, -1) + self.new_shape = (1, 3, 5, 1) + + +# Correct: There is duplicated axis. +class TestUnsqueezeOp3(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 2, 5) + self.axes = (0, 3, 3) + self.new_shape = (1, 3, 2, 1, 1, 5) + + +# Correct: Reversed axes. +class TestUnsqueezeOp4(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 2, 5) + self.axes = (3, 1, 1) + self.new_shape = (3, 1, 1, 2, 5, 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py index 14dd2bb06f9a18d0b15a4aee4e9e6bfdf8c41206..a324438ba5a3c3b57fd956bd11189ef7d50267e2 100644 --- a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py +++ b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py @@ -24,16 +24,13 @@ from op_test import OpTest class TestUnsqueezeOp(OpTest): def setUp(self): self.init_test_case() - self.op_type = "unsqueeze2" + self.op_type = "unsqueeze" self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")} self.init_attrs() - self.outputs = { - "Out": self.inputs["X"].reshape(self.new_shape), - "XShape": np.random.random(self.ori_shape).astype("float32") - } + self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)} def test_check_output(self): - self.check_output(no_check_set=["XShape"]) + self.check_output() def test_check_grad(self): self.check_grad(["X"], "Out")