提交 ca157793 编写于 作者: C chenweihang

rewrite, use reshape op in unsqueeze op, test passed

上级 996c157f
...@@ -265,6 +265,7 @@ op_library(recurrent_op DEPS executor) ...@@ -265,6 +265,7 @@ op_library(recurrent_op DEPS executor)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
op_library(cos_sim_op DEPS cos_sim_functor) op_library(cos_sim_op DEPS cos_sim_functor)
op_library(parallel_do_op DEPS executor) op_library(parallel_do_op DEPS executor)
op_library(unsqueeze_op DEPS reshape_op)
if (WITH_GPU) if (WITH_GPU)
op_library(conv_op DEPS vol2col depthwise_conv im2col) op_library(conv_op DEPS vol2col depthwise_conv im2col)
......
...@@ -12,41 +12,35 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,41 +12,35 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/unsqueeze_op.h"
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using framework::OpKernelType; class UnsqueezeOpInferShape : public framework::InferShapeBase {
using framework::Tensor;
class UnsqueezeOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; void operator()(framework::InferShapeContext *ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of UnsqueezeOp should not be null."); "Input(X) of UnsqueezeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of UnsqueezeOp should not be null."); "Output(Out) of UnsqueezeOp should not be null.");
const auto& axes = ctx->Attrs().Get<std::vector<int>>("axes"); const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
PADDLE_ENFORCE(!axes.empty(), PADDLE_ENFORCE(!axes.empty(),
"The unsqueeze axes information must be set by Attr(axes)."); "The unsqueeze axes information must be set by Attr(axes).");
const auto& x_dims = ctx->GetInputDim("X"); const auto &x_dims = ctx->GetInputDim("X");
// Validity Check: input tensor dims (<6). // Validity Check: input tensor dims (<6).
PADDLE_ENFORCE(x_dims.size() < 6, PADDLE_ENFORCE(static_cast<int>(x_dims.size()) <= 6,
"Invalid dimensions, dynamic dimensions should within " "Invalid dimensions, dynamic dimensions should within "
"[0, 5] dimensions (Eigen limit)."); "[1, 6] dimensions (Eigen limit).");
// Validity Check: the range of unsqueeze aixs. // Validity Check: the range of unsqueeze aixs.
// TODO(chenweihang): Don't consider negative axis?. for (int axis : axes) {
for (unsigned int idx = 0; idx < axes.size(); ++idx) { PADDLE_ENFORCE(axis < 6,
PADDLE_ENFORCE(axes[idx] < 6,
"Invalid dimensions, input axis should within " "Invalid dimensions, input axis should within "
"[0, 5] dimensions (Eigen limit)."); "[1, 6] dimensions (Eigen limit).");
} }
auto out_dims = GetOutputShape(axes, x_dims); auto out_dims = GetOutputShape(axes, x_dims);
...@@ -54,33 +48,7 @@ class UnsqueezeOp : public framework::OperatorWithKernel { ...@@ -54,33 +48,7 @@ class UnsqueezeOp : public framework::OperatorWithKernel {
} }
static framework::DDim GetOutputShape(const std::vector<int> unsqz_dims, static framework::DDim GetOutputShape(const std::vector<int> unsqz_dims,
const framework::DDim& in_dims) { const framework::DDim &in_dims) {
/*
* STL version
* Test Error! don't know why?.
std::vector<int64_t> output_shape;
// Contruct base output shape
for(int idx = 0; idx < in_dims.size(); ++idx) {
output_shape.emplace_back(in_dims[idx]);
}
// Validity Check: output dimensions limit.
PADDLE_ENFORCE(unsqz_dims.size() + output_shape.size() < 6,
"The Attr(axes) size is too large. The output shape should "
"be less than 6 (Eigne limit).");
// Insert the unsqueeze axis in turn.
auto it = output_shape.begin();
for (int axis : unsqz_dims) {
int cur = axis < 0 ? (axis + output_shape.size() + 1)
: axis;
// Vaildity Check: the axis bound
PADDLE_ENFORCE(cur >= 0 && cur <= static_cast<int>(output_shape.size()),
"The unsqueeze dims must be within range of current
rank.");
output_shape.emplace(it + axis, 1);
}
*/
unsigned int unsqz_mask = 0; unsigned int unsqz_mask = 0;
unsigned int front = 0, back = 0; unsigned int front = 0, back = 0;
int output_dims_size = in_dims.size(); int output_dims_size = in_dims.size();
...@@ -93,17 +61,17 @@ class UnsqueezeOp : public framework::OperatorWithKernel { ...@@ -93,17 +61,17 @@ class UnsqueezeOp : public framework::OperatorWithKernel {
cur >= 0 && cur <= output_dims_size, cur >= 0 && cur <= output_dims_size,
"The unsqueeze dims must be within range of current rank."); "The unsqueeze dims must be within range of current rank.");
// Save the front part. // Save the front part.
front = unsqz_mask & ((1 << axis) - 1); front = unsqz_mask & ((1 << cur) - 1);
// Move the back part. // Move the back part.
back = unsqz_mask & ~((1 << axis) - 1); back = unsqz_mask & ~((1 << cur) - 1);
back <<= 1; back <<= 1;
// Merge two part. // Merge two part.
back |= (1 << axis); back |= (1 << cur);
unsqz_mask = front | back; unsqz_mask = front | back;
// Add the output size. // Add the output size.
output_dims_size++; output_dims_size++;
// Validity Check: rank range. // Validity Check: rank range.
PADDLE_ENFORCE(output_dims_size < 6, PADDLE_ENFORCE(output_dims_size <= 6,
"The output tensor's rank should be less than 6."); "The output tensor's rank should be less than 6.");
} }
...@@ -121,6 +89,31 @@ class UnsqueezeOp : public framework::OperatorWithKernel { ...@@ -121,6 +89,31 @@ class UnsqueezeOp : public framework::OperatorWithKernel {
} }
}; };
class UnsqueezeOp : public framework::OperatorBase {
public:
UnsqueezeOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto &axes = Attr<std::vector<int>>("axes");
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
auto out_dims = UnsqueezeOpInferShape::GetOutputShape(axes, x_dims);
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(out_dims);
attrs["inplace"] = Attr<bool>("inplace");
// Invoke Reshape op.
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape", {{"X", {Input("X")}}, {"Shape", {}}},
{{"Out", {Output("Out")}}}, attrs);
reshape_op->Run(scope, place);
}
};
class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker { class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
...@@ -150,42 +143,49 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -150,42 +143,49 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
} }
}; };
class UnsqueezeGradOp : public framework::OperatorWithKernel { class UnsqueezeGradInferShape : public framework::InferShapeBase {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; void operator()(framework::InferShapeContext *ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of UnsqueezeGradOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Output(Out@GRAD) of UnsqueezeGradOp should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD("X", framework::GradVarName("X"));
} }
};
protected: class UnsqueezeGradOp : public framework::OperatorBase {
framework::OpKernelType GetExpectedKernelType( public:
const framework::ExecutionContext& ctx) const override { UnsqueezeGradOp(const std::string &type,
return framework::OpKernelType( const framework::VariableNameMap &inputs,
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()), const framework::VariableNameMap &outputs,
ctx.device_context()); const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto dx_name = Output(framework::GradVarName("X"));
auto dout_name = Input(framework::GradVarName("Out"));
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(x_dims);
attrs["inplace"] = Attr<bool>("inplace");
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}},
attrs);
reshape_op->Run(scope, place);
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
// Tell linker to use reshape op.
USE_OP(reshape);
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker, REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker,
ops::UnsqueezeOpInferShape,
paddle::framework::DefaultGradOpDescMaker<true>); paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp); REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp,
REGISTER_OP_CPU_KERNEL( ops::UnsqueezeGradInferShape);
unsqueeze, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, float>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, double>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
unsqueeze_grad,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/unsqueeze_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
unsqueeze, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, float>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, double>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
unsqueeze_grad,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class UnsqueezeKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *out = ctx.Output<framework::LoDTensor>("Out");
auto *in = ctx.Input<framework::LoDTensor>("X");
framework::DDim out_dims = out->dims();
bool inplace = ctx.Attr<bool>("inplace");
out->Resize(out_dims);
if (!inplace) {
out->mutable_data<T>(ctx.GetPlace());
framework::TensorCopySync(*in, ctx.GetPlace(), out);
out->Resize(out_dims);
} else {
out->ShareDataWith(*in);
out->Resize(out_dims);
}
}
};
template <typename DeviceContext, typename T>
class UnsqueezeGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto *d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
d_x->mutable_data<T>(ctx.GetPlace());
bool inplace = ctx.Attr<bool>("inplace");
auto in_dims = d_x->dims();
if (!inplace) {
framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x);
ctx.device_context().Wait();
d_x->Resize(in_dims);
} else {
d_x->ShareDataWith(*d_out);
d_x->Resize(in_dims);
}
}
};
} // namespace operators
} // namespace paddle
...@@ -27,7 +27,7 @@ class TestUnsqueezeOp(OpTest): ...@@ -27,7 +27,7 @@ class TestUnsqueezeOp(OpTest):
self.op_type = "unsqueeze" self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")} self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": False} self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self): def test_check_output(self):
...@@ -37,23 +37,42 @@ class TestUnsqueezeOp(OpTest): ...@@ -37,23 +37,42 @@ class TestUnsqueezeOp(OpTest):
self.check_grad(["X"], "Out") self.check_grad(["X"], "Out")
# Correct: There is mins axis. # Correct: Single input index.
class TestUnsqueezeOp1(OpTest):
def setUp(self):
ori_shape = (3, 5)
axes = (-1, )
new_shape = (3, 5, 1)
self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
# Correct: Mixed input axis.
class TestUnsqueezeOp2(OpTest): class TestUnsqueezeOp2(OpTest):
def setUp(self): def setUp(self):
ori_shape = (3, 5) ori_shape = (3, 5)
axes = (0, -2) axes = (0, -1)
new_shape = (1, 3, 1, 5) new_shape = (1, 3, 5, 1)
self.op_type = "unsqueeze" self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")} self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": False} self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(["X"], "Out") self.check_grad(["X"], "Out")
# Correct: There is duplicated axis. # Correct: There is duplicated axis.
...@@ -65,83 +84,84 @@ class TestUnsqueezeOp3(OpTest): ...@@ -65,83 +84,84 @@ class TestUnsqueezeOp3(OpTest):
self.op_type = "unsqueeze" self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")} self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": False} self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(["X"], "Out") self.check_grad(["X"], "Out")
# Error: Output dimension is error. # Correct: Inplace.
class TestUnsqueezeOp4(OpTest): class TestUnsqueezeOpInplace1(OpTest):
def setUp(self): def setUp(self):
ori_shape = (3, 2, 5) ori_shape = (3, 5)
axes = (0, 3) axes = (0, 2)
new_shape = (1, 3, 2, 2, 5) new_shape = (1, 3, 1, 5)
self.op_type = "unsqueeze" self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")} self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": False} self.attrs = {"axes": axes, "inplace": True}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(["X"], "Out") self.check_grad(["X"], "Out")
# Error: Input axes is invalid case 1. # Correct: Inplace. There is mins index.
class TestUnsqueezeOp5(OpTest): class TestUnsqueezeOpInplace2(OpTest):
def setUp(self): def setUp(self):
ori_shape = (3, 2, 5) ori_shape = (3, 5)
axes = (0, 5) axes = (0, -2)
new_shape = (1, 3, 1, 5) new_shape = (1, 3, 1, 5)
self.op_type = "unsqueeze" self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")} self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": False} self.attrs = {"axes": axes, "inplace": True}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(["X"], "Out") self.check_grad(["X"], "Out")
# Error: Input axes is invalid case 2. # Correct: Inplace. There is duplicated axis.
class TestUnsqueezeOp5(OpTest): class TestUnsqueezeOpInplace3(OpTest):
def setUp(self): def setUp(self):
ori_shape = (3, 2, 5) ori_shape = (3, 2, 5)
axes = (0, 2, 10) axes = (0, 3, 3)
new_shape = (1, 3, 1, 5) new_shape = (1, 3, 2, 1, 1, 5)
self.op_type = "unsqueeze" self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")} self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": False} self.attrs = {"axes": axes, "inplace": True}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(["X"], "Out") self.check_grad(["X"], "Out")
# Correct: Inplace. '''
class TestUnsqueezeOpInplace1(OpTest): # Error: Output dimension is error.
class TestUnsqueezeOp4(OpTest):
def setUp(self): def setUp(self):
ori_shape = (3, 5) ori_shape = (3, 5)
axes = (0, 2) axes = (0, 3)
new_shape = (1, 3, 1, 5) new_shape = (1, 3, 1, 1, 5)
self.op_type = "unsqueeze" self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")} self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inplace": True} self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self): def test_check_output(self):
...@@ -150,25 +170,60 @@ class TestUnsqueezeOpInplace1(OpTest): ...@@ -150,25 +170,60 @@ class TestUnsqueezeOpInplace1(OpTest):
def test_check_grad(self): def test_check_grad(self):
self.check_grad(["X"], "Out") self.check_grad(["X"], "Out")
# Error: Input axis is large than output range.
# Correct: Inplace. There is duplicated axis. class TestUnsqueezeOp5(OpTest):
class TestUnsqueezeOpInplace2(OpTest):
def setUp(self): def setUp(self):
ori_shape = (3, 2, 5) ori_shape = (3, 5)
axes = (0, 3, 3) axes = (0, 4)
new_shape = (1, 3, 2, 1, 1, 5) new_shape = (1, 3, 5, 1)
self.op_type = "unsqueeze" self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")} self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": True} self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(["X"], "Out") self.check_grad(["X"], "Out")
# Error: Input axes is large than Eigen limit.
class TestUnsqueezeOp6(OpTest):
def setUp(self):
ori_shape = (3, 5)
axes = (0, 2, 10)
new_shape = (1, 3, 1, 5, 1)
self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
# Error: Input axes size is large than Eigen limit.
class TestUnsqueezeOp7(OpTest):
def setUp(self):
ori_shape = (3, 5)
axes = (0, 2, 2, 2, 2, 2)
new_shape = (1, 3, 1, 1, 5, 1)
self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
'''
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册