Commit 9de45e11 authored by xzl

Fixed a bug when dims.size() == 1, renamed variables for clarity, and added a null check for input_grad.

Parent 35967e86
@@ -27,26 +27,29 @@ class TransposeOp : public framework::OperatorWithKernel {
   void InferShape(const framework::InferShapeContext &ctx) const override {
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Input"),
                             "Input(Input) should not be null");
+    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Output"),
+                            "Output(Output) should not be null");
     auto input_dim = ctx.Input<Tensor>("Input")->dims();
-    auto axis = ctx.Attr<std::vector<int>>("axis");
-    size_t input_dim_size = input_dim.size();
+    std::vector<int> axis = ctx.Attr<std::vector<int>>("axis");
+    size_t input_rank = input_dim.size();
     size_t axis_size = axis.size();
-    PADDLE_ENFORCE_EQ(input_dim_size, axis_size,
-                      "the input tensor's dimension(%d) "
+    PADDLE_ENFORCE_EQ(input_rank, axis_size,
+                      "the input tensor's rank(%d) "
                       "should be equal to the axis's size(%d)",
-                      input_dim_size, axis_size);
-    std::vector<int> axis_sorted(axis);
-    std::sort(axis_sorted.begin(), axis_sorted.end());
-    for (size_t i = 0; i < axis_sorted.size(); i++) {
-      PADDLE_ENFORCE_EQ(axis_sorted[i], static_cast<int>(i),
-                        "the sorted axis should be [0, 1, ... dims - 1], "
-                        "where the dims is the axis's size");
+                      input_rank, axis_size);
+    std::vector<int> count(axis_size, 0);
+    for (size_t i = 0; i < axis_size; i++) {
+      PADDLE_ENFORCE(
+          axis[i] < static_cast<int>(axis_size) && ++count[axis[i]] == 1,
+          "Each element of Attribute axis should be a unique value "
+          "range from 0 to (dims - 1), "
+          "where the dims is the axis's size");
     }
     framework::DDim output_dim(input_dim);
-    for (size_t i = 0; i < axis.size(); i++) {
+    for (size_t i = 0; i < axis_size; i++) {
       output_dim[i] = input_dim[axis[i]];
     }
     ctx.Output<framework::LoDTensor>("Output")->Resize(output_dim);
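Note: the new validation replaces sort-and-compare with a single counting pass: each axis entry must be in range and hit its `count` slot exactly once. A standalone sketch of the same idea (the `IsValidPermutation` helper is hypothetical, not part of the patch; unlike the in-place check it also rejects negative entries explicitly):

```cpp
#include <cstddef>
#include <vector>

// True iff `axis` is a permutation of {0, 1, ..., axis.size() - 1}.
bool IsValidPermutation(const std::vector<int>& axis) {
  std::vector<int> count(axis.size(), 0);
  for (size_t i = 0; i < axis.size(); i++) {
    const int a = axis[i];
    // Reject out-of-range or repeated entries; each value must be seen once.
    if (a < 0 || a >= static_cast<int>(axis.size()) || ++count[a] != 1) {
      return false;
    }
  }
  return true;
}
```

This runs in O(n) time rather than the O(n log n) of the old sorted-copy approach.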
@@ -60,12 +63,12 @@ class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput(
         "Input",
-        "(Tensor)The input tensor, tensors with rank at most 7 are supported");
+        "(Tensor)The input tensor, tensors with rank at most 6 are supported");
     AddOutput("Output", "(Tensor)The output tensor");
     AddAttr<std::vector<int>>(
         "axis",
         "(vector<int>)a list of values, and the size of the list should be "
-        "the same with the input tensor dimensions, the tensor will "
+        "the same with the input tensor rank, the tensor will "
         "permute the axes according the the values given");
     AddComment(R"DOC(
 The Tensor will be permuted according to the axis values given.
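For example, an input of shape (2, 3, 4) with axis = (1, 2, 0) produces an output of shape (3, 4, 2), since output_dim[i] = input_dim[axis[i]].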
@@ -97,18 +100,11 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
                             "Input(Input) should not be null");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Output")),
                             "Input(Output@GRAD) should not be null");
-    auto input_dims = ctx.Input<Tensor>("Input")->dims();
+    auto input_dim = ctx.Input<Tensor>("Input")->dims();
     auto *input_grad =
         ctx.Output<framework::LoDTensor>(framework::GradVarName("Input"));
-    auto output_grad_dims =
-        ctx.Input<Tensor>(framework::GradVarName("Output"))->dims();
-    auto output_dims = ctx.Input<Tensor>("Output")->dims();
-    PADDLE_ENFORCE(output_grad_dims == output_dims,
-                   "Output@GRAD dims must equal to Input(Input) dims");
-    input_grad->Resize(input_dims);
+    if (input_grad) input_grad->Resize(input_dim);
   }
 };
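Note: the gradient op's InferShape now tolerates a missing Input@GRAD output (a caller may not need the input's gradient), so it no longer dereferences a null pointer; the shape comparison between Output@GRAD and Output is dropped along with it.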
...
@@ -20,19 +20,19 @@
 namespace paddle {
 namespace operators {
 
-template <typename Place, typename T, int Dims>
+template <typename Place, typename T, int Rank>
 void EigenTranspose(const framework::ExecutionContext& context,
                     const framework::Tensor& in, framework::Tensor& out,
                     std::vector<int> axis) {
-  Eigen::array<int, Dims> permute;
-  for (int i = 0; i < Dims; i++) {
+  Eigen::array<int, Rank> permute;
+  for (int i = 0; i < Rank; i++) {
     permute[i] = axis[i];
   }
   auto in_dim = in.dims();
   auto out_dim = out.dims();
-  auto eigen_in = framework::EigenTensor<T, Dims>::From(in);
-  auto eigen_out = framework::EigenTensor<T, Dims>::From(out);
+  auto eigen_in = framework::EigenTensor<T, Rank>::From(in);
+  auto eigen_out = framework::EigenTensor<T, Rank>::From(out);
   auto& dev = context.GetEigenDevice<Place>();
   eigen_out.device(dev) = eigen_in.shuffle(permute);
 }
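For reference, `shuffle` uses the same convention as the operator: output dimension i takes its extent from input dimension permute[i]. Below is a minimal standalone sketch against Eigen's unsupported Tensor module (an illustration outside the Paddle framework, assuming Eigen headers are available):

```cpp
#include <cassert>
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  Eigen::Tensor<float, 3> in(2, 3, 4);
  in.setRandom();

  // output_dim[i] = input_dim[permute[i]], matching TransposeOp::InferShape.
  Eigen::array<int, 3> permute({1, 2, 0});
  Eigen::Tensor<float, 3> out = in.shuffle(permute);

  assert(out.dimension(0) == 3);  // took in's dim 1
  assert(out.dimension(1) == 4);  // took in's dim 2
  assert(out.dimension(2) == 2);  // took in's dim 0
  // Element-wise, out(i, j, k) == in(k, i, j).
  return 0;
}
```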
@@ -45,10 +45,11 @@ class TransposeKernel : public framework::OpKernel {
     auto* output = context.Output<framework::Tensor>("Output");
     output->mutable_data<T>(context.GetPlace());
 
-    auto axis = context.Attr<std::vector<int>>("axis");
+    std::vector<int> axis = context.Attr<std::vector<int>>("axis");
     int ndims = axis.size();
     switch (ndims) {
       case 1:
+        EigenTranspose<Place, T, 1>(context, *input, *output, axis);
         break;
       case 2:
         EigenTranspose<Place, T, 2>(context, *input, *output, axis);
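This hunk is the dims.size() == 1 fix from the commit message: `case 1` previously fell straight through to `break`, so a rank-1 input never had its data copied into the output. With the added call, the rank-1 shuffle is simply an identity copy.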
@@ -79,39 +80,48 @@ class TransposeGradKernel : public framework::OpKernel {
         context.Input<framework::Tensor>(framework::GradVarName("Output"));
     auto* input_grad =
         context.Output<framework::Tensor>(framework::GradVarName("Input"));
-    input_grad->mutable_data<T>(context.GetPlace());
-
-    auto axis_temp = context.Attr<std::vector<int>>("axis");
-    std::vector<int> axis(axis_temp);
-
-    for (size_t i = 0; i < axis.size(); i++) {
-      axis[axis_temp[i]] = i;
-    }
-
-    int ndims = axis.size();
-
-    switch (ndims) {
-      case 1:
-        break;
-      case 2:
-        EigenTranspose<Place, T, 2>(context, *output_grad, *input_grad, axis);
-        break;
-      case 3:
-        EigenTranspose<Place, T, 3>(context, *output_grad, *input_grad, axis);
-        break;
-      case 4:
-        EigenTranspose<Place, T, 4>(context, *output_grad, *input_grad, axis);
-        break;
-      case 5:
-        EigenTranspose<Place, T, 5>(context, *output_grad, *input_grad, axis);
-        break;
-      case 6:
-        EigenTranspose<Place, T, 6>(context, *output_grad, *input_grad, axis);
-        break;
-      default:
-        PADDLE_THROW("Tensors with rank at most 6 are supported");
+    if (input_grad) {
+      input_grad->mutable_data<T>(context.GetPlace());
+
+      std::vector<int> axis = context.Attr<std::vector<int>>("axis");
+      std::vector<int> reversed_axis(axis);
+
+      for (size_t i = 0; i < axis.size(); i++) {
+        reversed_axis[axis[i]] = i;
+      }
+
+      int ndims = axis.size();
+
+      switch (ndims) {
+        case 1:
+          EigenTranspose<Place, T, 1>(context, *output_grad, *input_grad,
+                                      reversed_axis);
+          break;
+        case 2:
+          EigenTranspose<Place, T, 2>(context, *output_grad, *input_grad,
+                                      reversed_axis);
+          break;
+        case 3:
+          EigenTranspose<Place, T, 3>(context, *output_grad, *input_grad,
+                                      reversed_axis);
+          break;
+        case 4:
+          EigenTranspose<Place, T, 4>(context, *output_grad, *input_grad,
+                                      reversed_axis);
+          break;
+        case 5:
+          EigenTranspose<Place, T, 5>(context, *output_grad, *input_grad,
+                                      reversed_axis);
+          break;
+        case 6:
+          EigenTranspose<Place, T, 6>(context, *output_grad, *input_grad,
+                                      reversed_axis);
+          break;
+        default:
+          PADDLE_THROW("Tensors with rank at most 6 are supported");
+      }
     }
   }
 };
 
 }  // namespace operators
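The renamed `reversed_axis` is the inverse permutation: the forward shuffle sends input dimension axis[i] to output position i, so the gradient shuffle must send output-gradient dimension i back to position axis[i]. A standalone sketch (hypothetical `InvertPermutation` helper, not part of the patch):

```cpp
#include <cstddef>
#include <vector>

// Inverts a permutation: if the forward shuffle put input dim axis[i] at
// output position i, the gradient shuffle must put output-grad dim i back
// at input position axis[i].
std::vector<int> InvertPermutation(const std::vector<int>& axis) {
  std::vector<int> reversed_axis(axis.size());
  for (size_t i = 0; i < axis.size(); i++) {
    reversed_axis[axis[i]] = static_cast<int>(i);
  }
  return reversed_axis;
}
```

For axis = {1, 2, 0} this gives reversed_axis = {2, 0, 1}; composing the two shuffles restores the original layout.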
...
@@ -22,6 +22,12 @@ class TestTransposeOp(OpTest):
         self.axis = (1, 0)
 
 
+class TestCase0(TestTransposeOp):
+    def initTestCase(self):
+        self.shape = (3, )
+        self.axis = (0, )
+
+
 class TestCase1(TestTransposeOp):
     def initTestCase(self):
         self.shape = (3, 4, 5)
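The new TestCase0 covers the rank-1 case, shape (3,) with axis (0,), exercising the code path that previously left the output unwritten.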
...