提交 52673956 编写于 作者: Z zhongpu 提交者: Jiabin Yang

add kernel for squeeze_op, test=develop (#19656)

* add kernel for squeeze_op, test=develop

* delete comment, test=develop
上级 2a81c367
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -12,26 +12,31 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/squeeze_op.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class SqueezeOpInferShape : public framework::InferShapeBase {
class SqueezeOp : public framework::OperatorWithKernel {
public:
void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of Squeeze operator should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of Squeeze operator should not be null.");
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Input(X) of Squeeze operator should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of Squeeze operator should not be null.");
const auto &x_dims = ctx->GetInputDim("X");
// Check input tensor dims (<6) Eigen limit.
PADDLE_ENFORCE(x_dims.size() <= 6,
"Invalid dimnesions, the rank of Input(X) "
"should be in the range of [1, 6] (Eigen limit).");
PADDLE_ENFORCE_LE(x_dims.size(), 6,
"Invalid dimnesions, the rank of Input(X) "
"should be in the range of [1, 6] (Eigen limit).");
const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
for (int a : axes) {
......@@ -40,7 +45,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
"tensor's rank.");
}
auto out_dims = GetOutputShape(axes, x_dims, false);
auto out_dims = GetOutputShape(axes, x_dims);
ctx->SetOutputDim("Out", out_dims);
if (x_dims[0] == out_dims[0]) {
// Only pass LoD when the first dimension of output and Input(X)
......@@ -50,8 +55,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
}
static framework::DDim GetOutputShape(const std::vector<int> squeeze_dims,
const framework::DDim &in_dims,
bool is_runtime) {
const framework::DDim &in_dims) {
size_t num_squeeze_dims = squeeze_dims.size();
int cnt_squeezed_dims = 0;
bool should_squeeze[9] = {false};
......@@ -70,14 +74,8 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size()
: squeeze_dims[idx];
// Check current index, the upper limit has beed checked in line 36.
PADDLE_ENFORCE(current >= 0,
"Invalid axis, the negative axis is out of range.");
if (is_runtime) {
PADDLE_ENFORCE(in_dims[current] == 1,
"Invalid axis index, the axis that will be squeezed "
"should be equal to 1.");
}
PADDLE_ENFORCE_GE(current, 0,
"Invalid axis, the negative axis is out of range.");
if (!(should_squeeze[current])) {
++cnt_squeezed_dims;
......@@ -96,27 +94,30 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
return framework::make_ddim(output_shape);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
}
};
// TODO(paddle-dev): Should use OpKernel.
class SqueezeOp : public framework::OperatorBase {
class SqueezeGradOp : public framework::OperatorWithKernel {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto &axes = Attr<std::vector<int>>("axes");
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
auto out_dims = SqueezeOpInferShape::GetOutputShape(axes, x_dims, true);
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize<int>(out_dims);
// Invoke Reshape Op
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape", {{"X", {Input("X")}}, {"Shape", {}}},
{{"Out", {Output("Out")}}}, attrs);
reshape_op->Run(scope, place);
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *context) const override {
context->SetOutputDim(framework::GradVarName("X"),
context->GetInputDim("X"));
context->ShareLoD("X", framework::GradVarName("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
}
};
......@@ -157,32 +158,70 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
class SqueezeGradInferShape : public framework::InferShapeBase {
class Squeeze2Op : public framework::OperatorWithKernel {
public:
void operator()(framework::InferShapeContext *context) const override {
context->SetOutputDim(framework::GradVarName("X"),
context->GetInputDim("X"));
context->ShareLoD("X", framework::GradVarName("X"));
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Input(X) of Squeeze operator should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of Squeeze operator should not be null.");
const auto &x_dims = ctx->GetInputDim("X");
// Check input tensor dims (<6) Eigen limit.
PADDLE_ENFORCE_LE(x_dims.size(), 6,
"Invalid dimnesions, the rank of Input(X) "
"should be in the range of [1, 6] (Eigen limit).");
const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
for (int a : axes) {
PADDLE_ENFORCE_LT(a, x_dims.size(),
"The squeeze axis should be less than input "
"tensor's rank.");
}
auto out_dims = SqueezeOp::GetOutputShape(axes, x_dims);
ctx->SetOutputDim("Out", out_dims);
if (x_dims[0] == out_dims[0]) {
// Only pass LoD when the first dimension of output and Input(X)
// are the same.
ctx->ShareLoD("X", "Out");
}
PADDLE_ENFORCE_EQ(ctx->HasOutput("XShape"), true,
"Output(XShape) of Squeeze operator should not be null.");
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims));
ctx->ShareLoD("X", /*->*/ "XShape");
}
};
class SqueezeGradOp : public framework::OperatorBase {
class Squeeze2GradOp : public framework::OperatorWithKernel {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto dx_name = Output(framework::GradVarName("X"));
auto dout_name = Input(framework::GradVarName("Out"));
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize<int>(x_dims);
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}},
attrs);
reshape_op->Run(scope, place);
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE_EQ(context->HasInput("XShape"), true,
"Input(XShape) shouldn't be null.");
PADDLE_ENFORCE_EQ(context->HasInput(framework::GradVarName("Out")), true,
"Input(Out@GRAD) shouldn't be null.");
auto xshape_dims = context->GetInputDim("XShape");
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
context->SetOutputDim(framework::GradVarName("X"), x_dims);
context->ShareLoD("XShape", framework::GradVarName("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
ctx.device_context());
}
};
......@@ -202,44 +241,6 @@ class Squeeze2OpMaker : public SqueezeOpMaker {
}
};
class Squeeze2OpInferShape : public SqueezeOpInferShape {
public:
void operator()(framework::InferShapeContext *ctx) const override {
SqueezeOpInferShape::operator()(ctx);
PADDLE_ENFORCE(ctx->HasOutput("XShape"),
"Output(XShape) of Squeeze operator should not be null.");
const auto &x_dims = ctx->GetInputDim("X");
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims));
ctx->ShareLoD("X", /*->*/ "XShape");
}
};
class Squeeze2Op : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto &axes = Attr<std::vector<int>>("axes");
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
auto out_dims = Squeeze2OpInferShape::GetOutputShape(axes, x_dims, true);
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize<int>(out_dims);
// Invoke Reshape Op
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape2", {{"X", {Input("X")}}, {"Shape", {}}},
{{"Out", {Output("Out")}}, {"XShape", {Output("XShape")}}}, attrs);
reshape_op->Run(scope, place);
}
};
class Squeeze2GradOpMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
......@@ -255,46 +256,6 @@ class Squeeze2GradOpMaker : public framework::SingleGradOpDescMaker {
}
};
class Squeeze2GradInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("XShape"),
"Input(XShape) shouldn't be null.");
PADDLE_ENFORCE(context->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) shouldn't be null.");
auto xshape_dims = context->GetInputDim("XShape");
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
context->SetOutputDim(framework::GradVarName("X"), x_dims);
context->ShareLoD("XShape", framework::GradVarName("X"));
}
};
class Squeeze2GradOp : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto dx_name = Output(framework::GradVarName("X"));
auto dout_name = Input(framework::GradVarName("Out"));
auto xshape_name = Input("XShape");
auto xshape_dims =
scope.FindVar(xshape_name)->Get<framework::LoDTensor>().dims();
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize<int>(x_dims);
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape2_grad", {{framework::GradVarName("Out"), {dout_name}},
{"Shape", {}},
{"XShape", {xshape_name}}},
{{framework::GradVarName("X"), {dx_name}}}, attrs);
reshape_op->Run(scope, place);
}
};
DECLARE_INPLACE_OP_INFERER(SequeezeInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(SequeezeGradInplaceInferer,
{framework::GradVarName("Out"),
......@@ -303,17 +264,39 @@ DECLARE_INPLACE_OP_INFERER(SequeezeGradInplaceInferer,
} // namespace operators
} // namespace paddle
// Tell linker to use reshape op
USE_OP(reshape);
namespace ops = paddle::operators;
REGISTER_OPERATOR(squeeze, ops::SqueezeOp, ops::SqueezeOpMaker,
ops::SqueezeOpInferShape,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(squeeze_grad, ops::SqueezeGradOp, ops::SqueezeGradInferShape);
REGISTER_OPERATOR(squeeze_grad, ops::SqueezeGradOp);
REGISTER_OPERATOR(squeeze2, ops::Squeeze2Op, ops::Squeeze2OpMaker,
ops::Squeeze2OpInferShape, ops::Squeeze2GradOpMaker,
ops::SequeezeInplaceInferer);
ops::Squeeze2GradOpMaker, ops::SequeezeInplaceInferer);
REGISTER_OPERATOR(squeeze2_grad, ops::Squeeze2GradOp,
ops::Squeeze2GradInferShape, ops::SequeezeGradInplaceInferer);
ops::SequeezeGradInplaceInferer);
REGISTER_OP_CPU_KERNEL(
squeeze, ops::SqueezeKernel<paddle::platform::CPUDeviceContext, float>,
ops::SqueezeKernel<paddle::platform::CPUDeviceContext, double>,
ops::SqueezeKernel<paddle::platform::CPUDeviceContext, int>,
ops::SqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::SqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
squeeze_grad,
ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
squeeze2, ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, float>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, double>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, int>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
squeeze2_grad,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, float>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, double>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, int>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/squeeze_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
squeeze, ops::SqueezeKernel<paddle::platform::CUDADeviceContext, float>,
ops::SqueezeKernel<paddle::platform::CUDADeviceContext, double>,
ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int>,
ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
squeeze_grad,
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
squeeze2, ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, float>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, double>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
squeeze2_grad,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, float>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, double>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int64_t>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/pooling.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class SqueezeKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *in = context.Input<framework::LoDTensor>("X");
auto *out = context.Output<framework::LoDTensor>("Out");
auto &axes = context.Attr<std::vector<int>>("axes");
auto x_dims = in->dims();
auto out_dims = GetOutputShape(axes, x_dims);
out->mutable_data(context.GetPlace(), in->type());
framework::TensorCopy(
*in, context.GetPlace(),
context.template device_context<platform::DeviceContext>(), out);
out->Resize(out_dims);
}
static framework::DDim GetOutputShape(const std::vector<int> squeeze_dims,
const framework::DDim &in_dims) {
size_t num_squeeze_dims = squeeze_dims.size();
int cnt_squeezed_dims = 0;
bool should_squeeze[9] = {false};
// Determines number of dimensions of output tensor after squeeze.
// Mark and count the dimensions need to be squeezed
if (num_squeeze_dims == 0) {
for (int idx = 0; idx < in_dims.size(); ++idx) {
if (in_dims[idx] == 1) {
should_squeeze[idx] = true;
++cnt_squeezed_dims;
}
}
} else {
for (size_t idx = 0; idx < num_squeeze_dims; ++idx) {
int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size()
: squeeze_dims[idx];
// Check current index, the upper limit has beed checked in line 36.
PADDLE_ENFORCE_GE(current, 0,
"Invalid axis, the negative axis is out of range.");
PADDLE_ENFORCE_EQ(in_dims[current], 1,
"Invalid axis index, the axis that will be squeezed "
"should be equal to 1.");
if (!(should_squeeze[current])) {
++cnt_squeezed_dims;
}
should_squeeze[current] = true;
}
}
// Make output dimensions
std::vector<int64_t> output_shape(in_dims.size() - cnt_squeezed_dims, 0);
for (int in_idx = 0, out_idx = 0; in_idx < in_dims.size(); ++in_idx) {
if (!should_squeeze[in_idx]) {
output_shape[out_idx++] = in_dims[in_idx];
}
}
return framework::make_ddim(output_shape);
}
};
template <typename DeviceContext, typename T>
class SqueezeGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *d_out =
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto *d_x = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto in_dims = ctx.Input<framework::LoDTensor>("X")->dims();
d_x->mutable_data(ctx.GetPlace(), d_out->type());
framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
d_x->Resize(in_dims);
}
};
template <typename DeviceContext, typename T>
class Squeeze2Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *out = context.Output<framework::LoDTensor>("Out");
auto *in = context.Input<framework::LoDTensor>("X");
auto &axes = context.Attr<std::vector<int>>("axes");
auto x_dims = in->dims();
auto out_dims =
SqueezeKernel<DeviceContext, T>::GetOutputShape(axes, x_dims);
out->mutable_data(context.GetPlace(), in->type());
framework::TensorCopy(
*in, context.GetPlace(),
context.template device_context<platform::DeviceContext>(), out);
out->Resize(out_dims);
}
};
template <typename DeviceContext, typename T>
class Squeeze2GradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *d_out =
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto *d_x = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
// auto in_dims = d_x->dims();
auto xshape_dims = ctx.Input<framework::LoDTensor>("XShape")->dims();
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
d_x->mutable_data(ctx.GetPlace(), d_out->type());
framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
d_x->Resize(x_dims);
}
};
} // namespace operators
} // namespace paddle
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
# Correct: General.
class TestSqueezeOp(OpTest):
def setUp(self):
self.op_type = "squeeze2"
self.init_test_case()
self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
self.init_attrs()
self.outputs = {
"Out": self.inputs["X"].reshape(self.new_shape),
"XShape": np.random.random(self.ori_shape).astype("float32")
}
def test_check_output(self):
self.check_output(no_check_set=['XShape'])
def test_check_grad(self):
self.check_grad(["X"], "Out")
def init_test_case(self):
self.ori_shape = (1, 3, 1, 5)
self.axes = (0, 2)
self.new_shape = (3, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes}
# Correct: There is mins axis.
class TestSqueezeOp1(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = (1, 3, 1, 5)
self.axes = (0, -2)
self.new_shape = (3, 5)
# Correct: No axes input.
class TestSqueezeOp2(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = (1, 3, 1, 5)
self.axes = ()
self.new_shape = (3, 5)
# Correct: Just part of axes be squeezed.
class TestSqueezeOp3(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = (3, 1, 5, 1, 4, 1)
self.axes = (1, -1)
self.new_shape = (3, 5, 1, 4)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -23,17 +23,14 @@ from op_test import OpTest
# Correct: General.
class TestSqueezeOp(OpTest):
def setUp(self):
self.op_type = "squeeze2"
self.op_type = "squeeze"
self.init_test_case()
self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
self.init_attrs()
self.outputs = {
"Out": self.inputs["X"].reshape(self.new_shape),
"XShape": np.random.random(self.ori_shape).astype("float32")
}
self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape), }
def test_check_output(self):
self.check_output(no_check_set=['XShape'])
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册