提交 5f89272c 编写于 作者: C chenweihang

change the bit insert to array insert for understandability

上级 fccdc1ab
......@@ -44,39 +44,37 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase {
static framework::DDim GetOutputShape(const std::vector<int> unsqz_dims,
const framework::DDim &in_dims) {
unsigned int unsqz_mask = 0;
unsigned int front = 0, back = 0;
int output_dims_size = in_dims.size();
int output_size = in_dims.size() + unsqz_dims.size();
int cur_output_size = in_dims.size();
std::vector<int64_t> output_shape(output_size, 0);
// Validity Check: rank range.
PADDLE_ENFORCE(output_size <= 6,
"The output tensor's rank should be less than 6.");
// Simulate insert by bit calc.
for (int axis : unsqz_dims) {
int cur = axis < 0 ? axis + output_dims_size + 1 : axis;
int cur = axis < 0 ? axis + cur_output_size + 1 : axis;
// Vaildity Check: the axis bound
PADDLE_ENFORCE(
cur >= 0 && cur <= output_dims_size,
cur >= 0 && cur <= cur_output_size,
"The unsqueeze dims must be within range of current rank.");
// Save the front part.
front = unsqz_mask & ((1 << cur) - 1);
// Move the back part.
back = unsqz_mask & ~((1 << cur) - 1);
back <<= 1;
// Merge two part.
back |= (1 << cur);
unsqz_mask = front | back;
// Move old axis, and insert new axis
for (int i = cur_output_size; i >= cur; --i) {
if (output_shape[i] == 1) {
// Move axis
output_shape[i + 1] = 1;
output_shape[i] = 0;
}
}
output_shape[cur] = 1;
// Add the output size.
output_dims_size++;
// Validity Check: rank range.
PADDLE_ENFORCE(output_dims_size <= 6,
"The output tensor's rank should be less than 6.");
cur_output_size++;
}
// Make output shape
std::vector<int64_t> output_shape(output_dims_size, 0);
for (int in_idx = 0, out_idx = 0; out_idx < output_dims_size; ++out_idx) {
if ((unsqz_mask & (1 << out_idx)) == 0) {
for (int in_idx = 0, out_idx = 0; out_idx < output_size; ++out_idx) {
if (output_shape[out_idx] == 0) {
output_shape[out_idx] = in_dims[in_idx++];
} else {
output_shape[out_idx] = 1;
}
}
......@@ -86,10 +84,7 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase {
class UnsqueezeOp : public framework::OperatorBase {
public:
UnsqueezeOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
......@@ -97,8 +92,6 @@ class UnsqueezeOp : public framework::OperatorBase {
auto &axes = Attr<std::vector<int>>("axes");
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
auto out_dims = UnsqueezeOpInferShape::GetOutputShape(axes, x_dims);
// auto out_dims =
// scope.FindVar(Output("Out"))->Get<framework::LoDTensor>().dims();
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(out_dims);
......@@ -165,11 +158,7 @@ class UnsqueezeGradInferShape : public framework::InferShapeBase {
class UnsqueezeGradOp : public framework::OperatorBase {
public:
UnsqueezeGradOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
......
......@@ -66,6 +66,14 @@ class TestUnsqueezeOp3(TestUnsqueezeOp):
self.new_shape = (1, 3, 2, 1, 1, 5)
# Correct: Reversed axes.
class TestUnsqueezeOp4(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = (3, 2, 5)
self.axes = (3, 1, 1)
self.new_shape = (3, 1, 1, 2, 5, 1)
# Correct: Inplace.
class TestUnsqueezeOpInplace1(TestUnsqueezeOp):
def init_test_case(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册