diff --git a/paddle/fluid/operators/squeeze_op.cc b/paddle/fluid/operators/squeeze_op.cc
index 114fab2488f497bbd0d476e76e191e93086263ef..7aeb1d961b1b53105131336eea9ef2a798c65213 100644
--- a/paddle/fluid/operators/squeeze_op.cc
+++ b/paddle/fluid/operators/squeeze_op.cc
@@ -1,4 +1,4 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -12,26 +12,31 @@
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include "paddle/fluid/operators/squeeze_op.h"
+#include <memory>
 #include <string>
+#include <unordered_map>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
 
 namespace paddle {
 namespace operators {
 
-class SqueezeOpInferShape : public framework::InferShapeBase {
+class SqueezeOp : public framework::OperatorWithKernel {
  public:
-  void operator()(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"),
-                   "Input(X) of Squeeze operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   "Output(Out) of Squeeze operator should not be null.");
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
+                      "Input(X) of Squeeze operator should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
+                      "Output(Out) of Squeeze operator should not be null.");
 
     const auto &x_dims = ctx->GetInputDim("X");
     // Check input tensor dims (<6) Eigen limit.
-    PADDLE_ENFORCE(x_dims.size() <= 6,
-                   "Invalid dimnesions, the rank of Input(X) "
-                   "should be in the range of [1, 6] (Eigen limit).");
+    PADDLE_ENFORCE_LE(x_dims.size(), 6,
+                      "Invalid dimensions, the rank of Input(X) "
+                      "should be in the range of [1, 6] (Eigen limit).");
 
     const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
     for (int a : axes) {
@@ -40,7 +45,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
                         "tensor's rank.");
     }
 
-    auto out_dims = GetOutputShape(axes, x_dims, false);
+    auto out_dims = GetOutputShape(axes, x_dims);
     ctx->SetOutputDim("Out", out_dims);
     if (x_dims[0] == out_dims[0]) {
       // Only pass LoD when the first dimension of output and Input(X)
@@ -50,8 +55,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
   }
 
   static framework::DDim GetOutputShape(const std::vector<int> squeeze_dims,
-                                        const framework::DDim &in_dims,
-                                        bool is_runtime) {
+                                        const framework::DDim &in_dims) {
     size_t num_squeeze_dims = squeeze_dims.size();
     int cnt_squeezed_dims = 0;
     bool should_squeeze[9] = {false};
@@ -70,14 +74,8 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
       int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size()
                                           : squeeze_dims[idx];
       // Check current index, the upper limit has beed checked in line 36.
-      PADDLE_ENFORCE(current >= 0,
-                     "Invalid axis, the negative axis is out of range.");
-
-      if (is_runtime) {
-        PADDLE_ENFORCE(in_dims[current] == 1,
-                       "Invalid axis index, the axis that will be squeezed "
-                       "should be equal to 1.");
-      }
+      PADDLE_ENFORCE_GE(current, 0,
+                        "Invalid axis, the negative axis is out of range.");
 
       if (!(should_squeeze[current])) {
         ++cnt_squeezed_dims;
@@ -96,27 +94,30 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
     return framework::make_ddim(output_shape);
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
+                                   ctx.device_context());
+  }
 };
 
-// TODO(paddle-dev): Should use OpKernel.
-class SqueezeOp : public framework::OperatorBase {
+class SqueezeGradOp : public framework::OperatorWithKernel {
  public:
-  using OperatorBase::OperatorBase;
-
- private:
-  void RunImpl(const framework::Scope &scope,
-               const platform::Place &place) const override {
-    auto &axes = Attr<std::vector<int>>("axes");
-    auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
-    auto out_dims = SqueezeOpInferShape::GetOutputShape(axes, x_dims, true);
-
-    framework::AttributeMap attrs;
-    attrs["shape"] = framework::vectorize<int>(out_dims);
-    // Invoke Reshape Op
-    auto reshape_op = framework::OpRegistry::CreateOp(
-        "reshape", {{"X", {Input("X")}}, {"Shape", {}}},
-        {{"Out", {Output("Out")}}}, attrs);
-    reshape_op->Run(scope, place);
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *context) const override {
+    context->SetOutputDim(framework::GradVarName("X"),
+                          context->GetInputDim("X"));
+    context->ShareLoD("X", framework::GradVarName("X"));
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
+                                   ctx.device_context());
   }
 };
 
@@ -157,32 +158,70 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
-class SqueezeGradInferShape : public framework::InferShapeBase {
+class Squeeze2Op : public framework::OperatorWithKernel {
  public:
-  void operator()(framework::InferShapeContext *context) const override {
-    context->SetOutputDim(framework::GradVarName("X"),
-                          context->GetInputDim("X"));
-    context->ShareLoD("X", framework::GradVarName("X"));
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
+                      "Input(X) of Squeeze operator should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
+                      "Output(Out) of Squeeze operator should not be null.");
+
+    const auto &x_dims = ctx->GetInputDim("X");
+    // Check input tensor dims (<6) Eigen limit.
+    PADDLE_ENFORCE_LE(x_dims.size(), 6,
+                      "Invalid dimensions, the rank of Input(X) "
+                      "should be in the range of [1, 6] (Eigen limit).");
+
+    const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
+    for (int a : axes) {
+      PADDLE_ENFORCE_LT(a, x_dims.size(),
+                        "The squeeze axis should be less than input "
+                        "tensor's rank.");
+    }
+
+    auto out_dims = SqueezeOp::GetOutputShape(axes, x_dims);
+    ctx->SetOutputDim("Out", out_dims);
+    if (x_dims[0] == out_dims[0]) {
+      // Only pass LoD when the first dimension of output and Input(X)
+      // are the same.
+ ctx->ShareLoD("X", "Out"); + } + + PADDLE_ENFORCE_EQ(ctx->HasOutput("XShape"), true, + "Output(XShape) of Squeeze operator should not be null."); + std::vector xshape_dims(x_dims.size() + 1); + xshape_dims[0] = 0; + for (int i = 0; i < x_dims.size(); ++i) { + xshape_dims[i + 1] = x_dims[i]; + } + ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims)); + ctx->ShareLoD("X", /*->*/ "XShape"); } }; -class SqueezeGradOp : public framework::OperatorBase { +class Squeeze2GradOp : public framework::OperatorWithKernel { public: - using OperatorBase::OperatorBase; - - private: - void RunImpl(const framework::Scope &scope, - const platform::Place &place) const override { - auto dx_name = Output(framework::GradVarName("X")); - auto dout_name = Input(framework::GradVarName("Out")); - auto x_dims = scope.FindVar(Input("X"))->Get().dims(); - framework::AttributeMap attrs; - attrs["shape"] = framework::vectorize(x_dims); - - auto reshape_op = framework::OpRegistry::CreateOp( - "reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}}, - attrs); - reshape_op->Run(scope, place); + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *context) const override { + PADDLE_ENFORCE_EQ(context->HasInput("XShape"), true, + "Input(XShape) shouldn't be null."); + PADDLE_ENFORCE_EQ(context->HasInput(framework::GradVarName("Out")), true, + "Input(Out@GRAD) shouldn't be null."); + auto xshape_dims = context->GetInputDim("XShape"); + auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size()); + context->SetOutputDim(framework::GradVarName("X"), x_dims); + context->ShareLoD("XShape", framework::GradVarName("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + ctx.Input(framework::GradVarName("Out"))->type(), + ctx.device_context()); } }; @@ -202,44 +241,6 @@ class Squeeze2OpMaker : public SqueezeOpMaker { } }; -class Squeeze2OpInferShape : public SqueezeOpInferShape { - public: - void operator()(framework::InferShapeContext *ctx) const override { - SqueezeOpInferShape::operator()(ctx); - PADDLE_ENFORCE(ctx->HasOutput("XShape"), - "Output(XShape) of Squeeze operator should not be null."); - const auto &x_dims = ctx->GetInputDim("X"); - std::vector xshape_dims(x_dims.size() + 1); - xshape_dims[0] = 0; - for (int i = 0; i < x_dims.size(); ++i) { - xshape_dims[i + 1] = x_dims[i]; - } - ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims)); - ctx->ShareLoD("X", /*->*/ "XShape"); - } -}; - -class Squeeze2Op : public framework::OperatorBase { - public: - using OperatorBase::OperatorBase; - - private: - void RunImpl(const framework::Scope &scope, - const platform::Place &place) const override { - auto &axes = Attr>("axes"); - auto x_dims = scope.FindVar(Input("X"))->Get().dims(); - auto out_dims = Squeeze2OpInferShape::GetOutputShape(axes, x_dims, true); - - framework::AttributeMap attrs; - attrs["shape"] = framework::vectorize(out_dims); - // Invoke Reshape Op - auto reshape_op = framework::OpRegistry::CreateOp( - "reshape2", {{"X", {Input("X")}}, {"Shape", {}}}, - {{"Out", {Output("Out")}}, {"XShape", {Output("XShape")}}}, attrs); - reshape_op->Run(scope, place); - } -}; - class Squeeze2GradOpMaker : public framework::SingleGradOpDescMaker { public: using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; @@ -255,46 +256,6 @@ class Squeeze2GradOpMaker : public framework::SingleGradOpDescMaker { } }; -class 
-class Squeeze2GradInferShape : public framework::InferShapeBase {
- public:
-  void operator()(framework::InferShapeContext *context) const override {
-    PADDLE_ENFORCE(context->HasInput("XShape"),
-                   "Input(XShape) shouldn't be null.");
-    PADDLE_ENFORCE(context->HasInput(framework::GradVarName("Out")),
-                   "Input(Out@GRAD) shouldn't be null.");
-    auto xshape_dims = context->GetInputDim("XShape");
-    auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
-    context->SetOutputDim(framework::GradVarName("X"), x_dims);
-    context->ShareLoD("XShape", framework::GradVarName("X"));
-  }
-};
-
-class Squeeze2GradOp : public framework::OperatorBase {
- public:
-  using OperatorBase::OperatorBase;
-
- private:
-  void RunImpl(const framework::Scope &scope,
-               const platform::Place &place) const override {
-    auto dx_name = Output(framework::GradVarName("X"));
-    auto dout_name = Input(framework::GradVarName("Out"));
-    auto xshape_name = Input("XShape");
-    auto xshape_dims =
-        scope.FindVar(xshape_name)->Get<framework::LoDTensor>().dims();
-    auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
-
-    framework::AttributeMap attrs;
-    attrs["shape"] = framework::vectorize<int>(x_dims);
-
-    auto reshape_op = framework::OpRegistry::CreateOp(
-        "reshape2_grad", {{framework::GradVarName("Out"), {dout_name}},
-                          {"Shape", {}},
-                          {"XShape", {xshape_name}}},
-        {{framework::GradVarName("X"), {dx_name}}}, attrs);
-    reshape_op->Run(scope, place);
-  }
-};
-
 DECLARE_INPLACE_OP_INFERER(SequeezeInplaceInferer, {"X", "Out"});
 DECLARE_INPLACE_OP_INFERER(SequeezeGradInplaceInferer,
                            {framework::GradVarName("Out"),
@@ -303,17 +264,39 @@ DECLARE_INPLACE_OP_INFERER(SequeezeGradInplaceInferer,
 }  // namespace operators
 }  // namespace paddle
 
-// Tell linker to use reshape op
-USE_OP(reshape);
-
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(squeeze, ops::SqueezeOp, ops::SqueezeOpMaker,
-                  ops::SqueezeOpInferShape,
                   paddle::framework::DefaultGradOpDescMaker<true>);
-REGISTER_OPERATOR(squeeze_grad, ops::SqueezeGradOp, ops::SqueezeGradInferShape);
+REGISTER_OPERATOR(squeeze_grad, ops::SqueezeGradOp);
 REGISTER_OPERATOR(squeeze2, ops::Squeeze2Op, ops::Squeeze2OpMaker,
-                  ops::Squeeze2OpInferShape, ops::Squeeze2GradOpMaker,
-                  ops::SequeezeInplaceInferer);
+                  ops::Squeeze2GradOpMaker, ops::SequeezeInplaceInferer);
 REGISTER_OPERATOR(squeeze2_grad, ops::Squeeze2GradOp,
-                  ops::Squeeze2GradInferShape, ops::SequeezeGradInplaceInferer);
+                  ops::SequeezeGradInplaceInferer);
+
+REGISTER_OP_CPU_KERNEL(
+    squeeze, ops::SqueezeKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SqueezeKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::SqueezeKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::SqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>,
+    ops::SqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>);
+REGISTER_OP_CPU_KERNEL(
+    squeeze_grad,
+    ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, int8_t>,
+    ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
+REGISTER_OP_CPU_KERNEL(
+    squeeze2, ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, float>,
+    ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, double>,
+    ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, int>,
+    ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, int8_t>,
+    ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, int64_t>);
+REGISTER_OP_CPU_KERNEL(
+    squeeze2_grad,
+    ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, int8_t>,
+    ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, int64_t>);
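Note: the shape rule implemented by GetOutputShape above is easy to sanity-check
outside the framework. Below is a minimal Python sketch of the same logic (the
helper name is hypothetical, not part of this patch): wrap negative axes, mark
each requested axis, or every size-1 axis when none are given, and keep the
unmarked dimensions. The compile-time InferShape path skips the size-1 check;
the kernels in squeeze_op.h enforce it, as this sketch does.

    def get_output_shape(squeeze_dims, in_dims):
        # Wrap negative axes and mark the dimensions to drop.
        rank = len(in_dims)
        if not squeeze_dims:
            # No axes given: squeeze every dimension of size 1.
            should_squeeze = [d == 1 for d in in_dims]
        else:
            should_squeeze = [False] * rank
            for axis in squeeze_dims:
                current = axis + rank if axis < 0 else axis
                assert 0 <= current < rank, "axis out of range"
                assert in_dims[current] == 1, "squeezed axis must have size 1"
                should_squeeze[current] = True
        return tuple(d for i, d in enumerate(in_dims) if not should_squeeze[i])

    # Shapes taken from the unit tests further down.
    assert get_output_shape((0, 2), (1, 3, 1, 5)) == (3, 5)
    assert get_output_shape((0, -2), (1, 3, 1, 5)) == (3, 5)
    assert get_output_shape((), (1, 3, 1, 5)) == (3, 5)
    assert get_output_shape((1, -1), (3, 1, 5, 1, 4, 1)) == (3, 5, 1, 4)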
diff --git a/paddle/fluid/operators/squeeze_op.cu.cc b/paddle/fluid/operators/squeeze_op.cu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..50fee1497e989fc2df93292253010a212c78a54f
--- /dev/null
+++ b/paddle/fluid/operators/squeeze_op.cu.cc
@@ -0,0 +1,44 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/squeeze_op.h"
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_CUDA_KERNEL(
+    squeeze, ops::SqueezeKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SqueezeKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
+    ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>);
+REGISTER_OP_CUDA_KERNEL(
+    squeeze_grad,
+    ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int8_t>,
+    ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
+REGISTER_OP_CUDA_KERNEL(
+    squeeze2, ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, float>,
+    ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, double>,
+    ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int>,
+    ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int8_t>,
+    ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int64_t>);
+REGISTER_OP_CUDA_KERNEL(
+    squeeze2_grad,
+    ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int8_t>,
+    ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int64_t>);
diff --git a/paddle/fluid/operators/squeeze_op.h b/paddle/fluid/operators/squeeze_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..5aae186527543dfe6e36de59fac058524e66bf59
--- /dev/null
+++ b/paddle/fluid/operators/squeeze_op.h
@@ -0,0 +1,146 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/operators/math/pooling.h"
+#include "paddle/fluid/platform/device_context.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename T>
+class SqueezeKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &context) const override {
+    auto *in = context.Input<framework::LoDTensor>("X");
+    auto *out = context.Output<framework::LoDTensor>("Out");
+
+    auto &axes = context.Attr<std::vector<int>>("axes");
+    auto x_dims = in->dims();
+    auto out_dims = GetOutputShape(axes, x_dims);
+
+    out->mutable_data(context.GetPlace(), in->type());
+    framework::TensorCopy(
+        *in, context.GetPlace(),
+        context.template device_context<platform::DeviceContext>(), out);
+    out->Resize(out_dims);
+  }
+
+  static framework::DDim GetOutputShape(const std::vector<int> squeeze_dims,
+                                        const framework::DDim &in_dims) {
+    size_t num_squeeze_dims = squeeze_dims.size();
+    int cnt_squeezed_dims = 0;
+    bool should_squeeze[9] = {false};
+
+    // Determines number of dimensions of output tensor after squeeze.
+    // Mark and count the dimensions that need to be squeezed.
+    if (num_squeeze_dims == 0) {
+      for (int idx = 0; idx < in_dims.size(); ++idx) {
+        if (in_dims[idx] == 1) {
+          should_squeeze[idx] = true;
+          ++cnt_squeezed_dims;
+        }
+      }
+    } else {
+      for (size_t idx = 0; idx < num_squeeze_dims; ++idx) {
+        int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size()
+                                            : squeeze_dims[idx];
+        // Check current index, the upper limit has been checked above.
+        PADDLE_ENFORCE_GE(current, 0,
+                          "Invalid axis, the negative axis is out of range.");
+
+        PADDLE_ENFORCE_EQ(in_dims[current], 1,
+                          "Invalid axis index, the axis that will be squeezed "
+                          "should be equal to 1.");
+
+        if (!(should_squeeze[current])) {
+          ++cnt_squeezed_dims;
+        }
+        should_squeeze[current] = true;
+      }
+    }
+
+    // Make output dimensions
+    std::vector<int64_t> output_shape(in_dims.size() - cnt_squeezed_dims, 0);
+    for (int in_idx = 0, out_idx = 0; in_idx < in_dims.size(); ++in_idx) {
+      if (!should_squeeze[in_idx]) {
+        output_shape[out_idx++] = in_dims[in_idx];
+      }
+    }
+
+    return framework::make_ddim(output_shape);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class SqueezeGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *d_out =
+        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
+    auto *d_x = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
+    auto in_dims = ctx.Input<framework::LoDTensor>("X")->dims();
+
+    d_x->mutable_data(ctx.GetPlace(), d_out->type());
+    framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
+    d_x->Resize(in_dims);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class Squeeze2Kernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &context) const override {
+    auto *out = context.Output<framework::LoDTensor>("Out");
+    auto *in = context.Input<framework::LoDTensor>("X");
+
+    auto &axes = context.Attr<std::vector<int>>("axes");
+
+    auto x_dims = in->dims();
+    auto out_dims =
+        SqueezeKernel<DeviceContext, T>::GetOutputShape(axes, x_dims);
+
+    out->mutable_data(context.GetPlace(), in->type());
+    framework::TensorCopy(
+        *in, context.GetPlace(),
+        context.template device_context<platform::DeviceContext>(), out);
+    out->Resize(out_dims);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class Squeeze2GradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *d_out =
+        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
+    auto *d_x = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
+
+    auto xshape_dims = ctx.Input<framework::LoDTensor>("XShape")->dims();
+    auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
+
+    d_x->mutable_data(ctx.GetPlace(), d_out->type());
+    framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
+    d_x->Resize(x_dims);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
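Note: squeeze2's extra XShape output exists only so the backward op can recover
the input shape without keeping X alive: InferShape stores (0, x_dims...), and
Squeeze2GradOp/Squeeze2GradKernel slice off the leading placeholder to rebuild
the gradient's shape. A small NumPy sketch of that round trip (variable names
hypothetical):

    import numpy as np

    x = np.random.random((1, 3, 1, 5)).astype("float32")
    xshape = (0,) + x.shape                    # what Squeeze2Op records in XShape
    d_out = np.ones((3, 5), dtype="float32")   # incoming gradient w.r.t. Out

    # Squeeze2GradKernel: drop the placeholder, reshape the gradient back.
    d_x = d_out.reshape(xshape[1:])
    assert d_x.shape == x.shape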
diff --git a/python/paddle/fluid/tests/unittests/test_squeeze2_op.py b/python/paddle/fluid/tests/unittests/test_squeeze2_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad9391eac3304965d6ee5d007fce70a5d0dd1b18
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_squeeze2_op.py
@@ -0,0 +1,75 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+
+from op_test import OpTest
+
+
+# Correct: General.
+class TestSqueezeOp(OpTest):
+    def setUp(self):
+        self.op_type = "squeeze2"
+        self.init_test_case()
+        self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
+        self.init_attrs()
+        self.outputs = {
+            "Out": self.inputs["X"].reshape(self.new_shape),
+            "XShape": np.random.random(self.ori_shape).astype("float32")
+        }
+
+    def test_check_output(self):
+        self.check_output(no_check_set=['XShape'])
+
+    def test_check_grad(self):
+        self.check_grad(["X"], "Out")
+
+    def init_test_case(self):
+        self.ori_shape = (1, 3, 1, 5)
+        self.axes = (0, 2)
+        self.new_shape = (3, 5)
+
+    def init_attrs(self):
+        self.attrs = {"axes": self.axes}
+
+
+# Correct: There is a minus (negative) axis.
+class TestSqueezeOp1(TestSqueezeOp):
+    def init_test_case(self):
+        self.ori_shape = (1, 3, 1, 5)
+        self.axes = (0, -2)
+        self.new_shape = (3, 5)
+
+
+# Correct: No axes input.
+class TestSqueezeOp2(TestSqueezeOp):
+    def init_test_case(self):
+        self.ori_shape = (1, 3, 1, 5)
+        self.axes = ()
+        self.new_shape = (3, 5)
+
+
+# Correct: Just part of the axes are squeezed.
+class TestSqueezeOp3(TestSqueezeOp):
+    def init_test_case(self):
+        self.ori_shape = (3, 1, 5, 1, 4, 1)
+        self.axes = (1, -1)
+        self.new_shape = (3, 5, 1, 4)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_squeeze_op.py b/python/paddle/fluid/tests/unittests/test_squeeze_op.py
index 204a4bb40196bd1fc2f5861aa31cf9560ea4d349..8a43f5c3e1e31099da155ba7d730c5085f7d26d2 100644
--- a/python/paddle/fluid/tests/unittests/test_squeeze_op.py
+++ b/python/paddle/fluid/tests/unittests/test_squeeze_op.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -23,17 +23,14 @@ from op_test import OpTest
 # Correct: General.
 class TestSqueezeOp(OpTest):
     def setUp(self):
-        self.op_type = "squeeze2"
+        self.op_type = "squeeze"
         self.init_test_case()
         self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
         self.init_attrs()
-        self.outputs = {
-            "Out": self.inputs["X"].reshape(self.new_shape),
-            "XShape": np.random.random(self.ori_shape).astype("float32")
-        }
+        self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)}
 
     def test_check_output(self):
-        self.check_output(no_check_set=['XShape'])
+        self.check_output()
 
     def test_check_grad(self):
         self.check_grad(["X"], "Out")
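Note: end to end, these ops back the Python-side squeeze layer. A usage sketch
against the fluid 1.x API of this era (assuming fluid.layers.squeeze still
lowers to the squeeze2 op, as the layer code did at the time):

    import numpy as np
    import paddle.fluid as fluid

    x = fluid.layers.data(
        name="x", shape=[1, 3, 1, 5], dtype="float32", append_batch_size=False)
    y = fluid.layers.squeeze(input=x, axes=[0, 2])

    exe = fluid.Executor(fluid.CPUPlace())
    out, = exe.run(fluid.default_main_program(),
                   feed={"x": np.ones((1, 3, 1, 5), dtype="float32")},
                   fetch_list=[y])
    print(out.shape)  # expected: (3, 5)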