未验证 提交 88685255 编写于 作者: C chengduo 提交者: GitHub

Refine reshape_grad and transpose_grad (#13074)

* Add intermediate

* fix flatten/squeeze/unsqueeze

* Considering compatibility issues, we could not fix the origin op

* follow comment

* reset the shape of XShape
上级 7dd8adb5
......@@ -157,6 +157,116 @@ class FlattenGradOp : public framework::OperatorBase {
}
};
// FIXME(zcd): flatten2 adds an intermediate output(XShape) based on flatten,
// the XShape is used to carry the shape and lod of X which will be used in
// flatten_grad, in this way, the framework can reuse the memory of X
// immediately the flatten2_op is finished.
// Considering compatibility issues, we could not fix flatten2_op
class Flatten2OpInferShape : public FlattenOpInferShape {
public:
void operator()(framework::InferShapeContext *ctx) const override {
FlattenOpInferShape::operator()(ctx);
PADDLE_ENFORCE(ctx->HasOutput("XShape"),
"Output (XShape) of Flatten op should not be null.");
const auto &in_dims = ctx->GetInputDim("X");
std::vector<int64_t> xshape_dims(in_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < in_dims.size(); ++i) {
xshape_dims[i + 1] = in_dims[i];
}
ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims));
ctx->ShareLoD("X", "XShape");
}
};
class Flatten2Op : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto &axis = Attr<int>("axis");
auto in_dims =
scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
const auto &out_dims = FlattenOpInferShape::GetOutputShape(axis, in_dims);
framework::AttributeMap attrs;
attrs["shape"] = out_dims;
attrs["inplace"] = false;
// Invoke Reshape Op
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape2", {{"X", {Input("X")}}, {"Shape", {}}},
{{"Out", {Output("Out")}}, {"XShape", {Output("XShape")}}}, attrs);
reshape_op->Run(scope, place);
}
};
class Flatten2OpMaker : public FlattenOpMaker {
public:
void Make() override {
FlattenOpMaker::Make();
AddOutput("XShape",
"XShape is just used to store the shape and lod of X, which will "
"be used in FlattenGradOp.")
.AsIntermediate();
}
};
class Flatten2GradOpMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
std::unique_ptr<framework::OpDesc> Apply() const override {
auto *grad_op = new framework::OpDesc();
grad_op->SetType("flatten2_grad");
grad_op->SetInput("XShape", Output("XShape"));
grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
grad_op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDesc>(grad_op);
}
};
class Flatten2GradInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("XShape"),
"Input(XShape) shouldn't be null.");
PADDLE_ENFORCE(context->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) shouldn't be null.");
auto xshape_dims = context->GetInputDim("XShape");
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
context->SetOutputDim(framework::GradVarName("X"), x_dims);
context->ShareLoD("XShape", framework::GradVarName("X"));
}
};
class Flatten2GradOp : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto dx_name = Output(framework::GradVarName("X"));
auto dout_name = Input(framework::GradVarName("Out"));
auto xshape_name = Input("XShape");
auto xshape_dims =
scope.FindVar(xshape_name)->Get<framework::LoDTensor>().dims();
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(x_dims);
attrs["inplace"] = false;
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape2", {{"X", {dout_name}}, {"Shape", {}}},
{{"Out", {dx_name}}, {"XShape", {xshape_name}}}, attrs);
reshape_op->Run(scope, place);
}
};
} // namespace operators
} // namespace paddle
......@@ -167,3 +277,8 @@ REGISTER_OPERATOR(flatten, ops::FlattenOp, ops::FlattenOpMaker,
ops::FlattenOpInferShape,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(flatten_grad, ops::FlattenGradOp, ops::FlattenGradInferShape);
REGISTER_OPERATOR(flatten2, ops::Flatten2Op, ops::Flatten2OpMaker,
ops::Flatten2OpInferShape, ops::Flatten2GradOpMaker);
REGISTER_OPERATOR(flatten2_grad, ops::Flatten2GradOp,
ops::Flatten2GradInferShape);
......@@ -246,6 +246,88 @@ class ReshapeGradKernel {
}
};
// FIXME(zcd): reshape2 adds an intermediate output(XShape) based on reshape,
// the XShape is used to carry the shape and lod of X which will be used in
// reshape_grad, in this way, the framework can reuse the memory of X
// immediately the reshape_op is finished.
// Considering compatibility issues, we could not fix reshape_op
class Reshape2Op : public ReshapeOp {
public:
Reshape2Op(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: ReshapeOp(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
ReshapeOp::InferShape(ctx);
PADDLE_ENFORCE(ctx->HasOutput("XShape"),
"Output(XShape) of ReshapeOp should not be null.");
const auto &x_dims = ctx->GetInputDim("X");
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims));
ctx->ShareLoD("X", /*->*/ "XShape");
}
};
class Reshape2OpMaker : public ReshapeOpMaker {
public:
void Make() override {
ReshapeOpMaker::Make();
AddOutput("XShape",
"XShape is just used to store the shape and lod of X, which will "
"be used in FlattenGradOp.")
.AsIntermediate();
}
};
class Reshape2GradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
std::unique_ptr<framework::OpDesc> Apply() const override {
auto *grad_op = new framework::OpDesc();
grad_op->SetType("reshape2_grad");
grad_op->SetInput("XShape", Output("XShape"));
grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
grad_op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDesc>(grad_op);
}
};
class Reshape2GradOp : public framework::OperatorWithKernel {
public:
Reshape2GradOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("XShape"), "Input(XShape) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) shouldn't be null.");
auto xshape_dims = ctx->GetInputDim("XShape");
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
ctx->ShareLoD("XShape", framework::GradVarName("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))
->type()),
ctx.device_context());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
......@@ -261,6 +343,17 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel);
REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker,
ops::Reshape2GradMaker);
REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel,
int64_t, ops::ReshapeKernel);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
double, ops::ReshapeGradKernel, int,
ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel);
#ifdef PADDLE_WITH_CUDA
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel,
......@@ -269,4 +362,11 @@ REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
double, ops::ReshapeGradKernel, int,
ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel);
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel,
int64_t, ops::ReshapeKernel);
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
double, ops::ReshapeGradKernel, int,
ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel);
#endif
......@@ -126,15 +126,15 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault({});
AddComment(R"DOC(
Squeeze Operator.
Remove single-dimensional entries from the shape of a tensor.
Takes a parameter axes with a list of axes to squeeze.
If axes is not provided, all the single dimensions will be removed from the shape.
Remove single-dimensional entries from the shape of a tensor.
Takes a parameter axes with a list of axes to squeeze.
If axes is not provided, all the single dimensions will be removed from the shape.
If an axis is selected with shape entry not equal to one, an error is raised.
Examples:
Case 1:
Given
Given
X.shape = (1, 3, 1, 5)
and
axes = [0]
......@@ -144,7 +144,7 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
Case 2:
Given
X.shape = (1, 3, 1, 5)
and
and
axes = []
we get:
Out.shape = (3, 5)
......@@ -181,6 +181,113 @@ class SqueezeGradOp : public framework::OperatorBase {
}
};
// FIXME(zcd): squeeze2 adds an intermediate output(XShape) based on squeeze,
// the XShape is used to carry the shape and lod of X which will be used in
// squeeze_grad, in this way, the framework can reuse the memory of X
// immediately the squeeze2_op is finished.
// Considering compatibility issues, we could not fix squeeze2_op
class Squeeze2OpMaker : public SqueezeOpMaker {
public:
void Make() override {
SqueezeOpMaker::Make();
AddOutput("XShape",
"XShape is just used to store the shape and lod of X, which will "
"be used in SqueezeGradOp.")
.AsIntermediate();
}
};
class Squeeze2OpInferShape : public SqueezeOpInferShape {
public:
void operator()(framework::InferShapeContext *ctx) const override {
SqueezeOpInferShape::operator()(ctx);
PADDLE_ENFORCE(ctx->HasOutput("XShape"),
"Output(XShape) of Squeeze operator should not be null.");
const auto &x_dims = ctx->GetInputDim("X");
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims));
ctx->ShareLoD("X", /*->*/ "XShape");
}
};
class Squeeze2Op : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto &axes = Attr<std::vector<int>>("axes");
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
auto out_dims = Squeeze2OpInferShape::GetOutputShape(axes, x_dims);
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(out_dims);
// Invoke Reshape Op
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape2", {{"X", {Input("X")}}, {"Shape", {}}},
{{"Out", {Output("Out")}}, {"XShape", {Output("XShape")}}}, attrs);
reshape_op->Run(scope, place);
}
};
class Squeeze2GradOpMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
std::unique_ptr<framework::OpDesc> Apply() const override {
auto *grad_op = new framework::OpDesc();
grad_op->SetType("squeeze2_grad");
grad_op->SetInput("XShape", Output("XShape"));
grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
grad_op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDesc>(grad_op);
}
};
class Squeeze2GradInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("XShape"),
"Input(XShape) shouldn't be null.");
PADDLE_ENFORCE(context->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) shouldn't be null.");
auto xshape_dims = context->GetInputDim("XShape");
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
context->SetOutputDim(framework::GradVarName("X"), x_dims);
context->ShareLoD("XShape", framework::GradVarName("X"));
}
};
class Squeeze2GradOp : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto dx_name = Output(framework::GradVarName("X"));
auto dout_name = Input(framework::GradVarName("Out"));
auto xshape_name = Input("XShape");
auto xshape_dims =
scope.FindVar(xshape_name)->Get<framework::LoDTensor>().dims();
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(x_dims);
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape2", {{"X", {dout_name}}, {"Shape", {}}},
{{"Out", {dx_name}}, {"XShape", {xshape_name}}}, attrs);
reshape_op->Run(scope, place);
}
};
} // namespace operators
} // namespace paddle
......@@ -192,3 +299,8 @@ REGISTER_OPERATOR(squeeze, ops::SqueezeOp, ops::SqueezeOpMaker,
ops::SqueezeOpInferShape,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(squeeze_grad, ops::SqueezeGradOp, ops::SqueezeGradInferShape);
REGISTER_OPERATOR(squeeze2, ops::Squeeze2Op, ops::Squeeze2OpMaker,
ops::Squeeze2OpInferShape, ops::Squeeze2GradOpMaker);
REGISTER_OPERATOR(squeeze2_grad, ops::Squeeze2GradOp,
ops::Squeeze2GradInferShape);
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/transpose_op.h"
#include <string>
#include <vector>
namespace paddle {
......@@ -24,7 +25,7 @@ class TransposeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null");
auto x_dims = ctx->GetInputDim("X");
......@@ -90,7 +91,7 @@ The behavior of this operator is similar to how `numpy.transpose` works.
2 &5
\end{pmatrix}$$
- Given a input tensor with shape $(N, C, H, W)$ and the `axes` is
- Given a input tensor with shape $(N, C, H, W)$ and the `axes` is
$[0, 2, 3, 1]$, then shape of the output tensor will be: $(N, H, W, C)$.
)DOC");
......@@ -101,7 +102,7 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
......@@ -113,6 +114,93 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
}
};
// FIXME(zcd): transpose2 adds an intermediate output(XShape) based on
// transpose, the XShape is used to carry the shape and lod of X which
// will be used in transpose_grad, in this way, the framework can reuse
// the memory of X immediately the transpose2_op is finished.
// Considering compatibility issues, we could not fix transpose2_op
class Transpose2Op : public TransposeOp {
public:
Transpose2Op(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: TransposeOp(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
TransposeOp::InferShape(ctx);
PADDLE_ENFORCE(ctx->HasOutput("XShape"),
"Output(XShape) should not be null");
const auto &in_dims = ctx->GetInputDim("X");
std::vector<int64_t> x_shape_dim(in_dims.size() + 1);
x_shape_dim[0] = 0;
for (int i = 0; i < in_dims.size(); ++i) {
x_shape_dim[i + 1] = in_dims[i];
}
ctx->SetOutputDim("XShape", framework::make_ddim(x_shape_dim));
ctx->ShareLoD("X", /*->*/ "XShape");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
}
};
class Transpose2OpMaker : public TransposeOpMaker {
public:
void Make() override {
TransposeOpMaker::Make();
AddOutput("XShape", "(Tensor)The output tensor.").AsIntermediate();
}
};
class Transpose2GradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
std::unique_ptr<framework::OpDesc> Apply() const override {
auto *grad_op = new framework::OpDesc();
grad_op->SetType("transpose2_grad");
grad_op->SetInput("XShape", Output("XShape"));
grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
grad_op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDesc>(grad_op);
}
};
class Transpose2OpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("XShape"), "Input(XShape) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
if (ctx->HasOutput(framework::GradVarName("X"))) {
auto xshape_dim = ctx->GetInputDim("XShape");
auto x_shape_dim =
framework::slice_ddim(xshape_dim, 1, xshape_dim.size());
ctx->SetOutputDim(framework::GradVarName("X"), x_shape_dim);
ctx->ShareLoD("XShape", framework::GradVarName("X"));
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))
->type()),
ctx.device_context());
}
};
} // namespace operators
} // namespace paddle
......@@ -120,8 +208,20 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(transpose, ops::TransposeOp, ops::TransposeOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(transpose_grad, ops::TransposeOpGrad);
REGISTER_OP_CPU_KERNEL(
transpose, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
transpose_grad,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OPERATOR(transpose2, ops::Transpose2Op, ops::Transpose2OpMaker,
ops::Transpose2GradMaker);
REGISTER_OPERATOR(transpose2_grad, ops::Transpose2OpGrad);
REGISTER_OP_CPU_KERNEL(
transpose2,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
transpose2_grad,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>);
......@@ -21,3 +21,10 @@ REGISTER_OP_CUDA_KERNEL(
REGISTER_OP_CUDA_KERNEL(
transpose_grad,
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_CUDA_KERNEL(
transpose2,
ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_CUDA_KERNEL(
transpose2_grad,
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>);
......@@ -127,13 +127,13 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
});
AddComment(R"DOC(
Unsqueeze Operator.
Insert single-dimensional entries to the shape of a tensor.
Takes one required argument axes, a list of dimensions that will be inserted.
Dimension indices in axes are as seen in the output tensor.
For example:
Given a tensor such that tensor with shape [3, 4, 5],
Insert single-dimensional entries to the shape of a tensor.
Takes one required argument axes, a list of dimensions that will be inserted.
Dimension indices in axes are as seen in the output tensor.
For example:
Given a tensor such that tensor with shape [3, 4, 5],
then Unsqueeze(tensor, axes=[0, 4]) has shape [1, 3, 4, 5, 1]
)DOC");
}
......@@ -168,6 +168,112 @@ class UnsqueezeGradOp : public framework::OperatorBase {
}
};
// FIXME(zcd): unsqueeze2 adds an intermediate output(XShape) based on
// unsqueeze, the XShape is used to carry the shape and lod of X which
// will be used in unsqueeze_grad, in this way, the framework can reuse
// the memory of X immediately the unsqueeze2_op is finished.
// Considering compatibility issues, we could not fix unsqueeze2_op
class Unsqueeze2OpInferShape : public UnsqueezeOpInferShape {
public:
void operator()(framework::InferShapeContext *ctx) const override {
UnsqueezeOpInferShape::operator()(ctx);
PADDLE_ENFORCE(ctx->HasOutput("XShape"),
"Output(XShape) of Unsqueeze operator should not be null.");
const auto &x_dims = ctx->GetInputDim("X");
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims));
ctx->ShareLoD("X", /*->*/ "XShape");
}
};
class Unsqueeze2OpMaker : public UnsqueezeOpMaker {
public:
void Make() override {
UnsqueezeOpMaker::Make();
AddOutput("XShape",
"XShape is just used to store the shape and lod of X, which will "
"be used in UnsqueezeGradOp.")
.AsIntermediate();
}
};
class Unsqueeze2Op : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto &axes = Attr<std::vector<int>>("axes");
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
auto out_dims = Unsqueeze2OpInferShape::GetOutputShape(axes, x_dims);
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(out_dims);
// Invoke Reshape op.
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape2", {{"X", {Input("X")}}, {"Shape", {}}},
{{"Out", {Output("Out")}}, {"XShape", {Output("XShape")}}}, attrs);
reshape_op->Run(scope, place);
}
};
class Unsqueeze2GradOpMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
std::unique_ptr<framework::OpDesc> Apply() const override {
auto *grad_op = new framework::OpDesc();
grad_op->SetType("unsqueeze2_grad");
grad_op->SetInput("XShape", Output("XShape"));
grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
grad_op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDesc>(grad_op);
}
};
class Unsqueeze2GradInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("XShape"),
"Input(XShape) shouldn't be null.");
PADDLE_ENFORCE(context->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) shouldn't be null.");
auto xshape_dims = context->GetInputDim("XShape");
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
context->SetOutputDim(framework::GradVarName("X"), x_dims);
context->ShareLoD("XShape", framework::GradVarName("X"));
}
};
class Unsqueeze2GradOp : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto dx_name = Output(framework::GradVarName("X"));
auto dout_name = Input(framework::GradVarName("Out"));
auto xshape_name = Input("XShape");
auto xshape_dims =
scope.FindVar(xshape_name)->Get<framework::LoDTensor>().dims();
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(x_dims);
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape2", {{"X", {dout_name}}, {"Shape", {}}},
{{"Out", {dx_name}}, {"XShape", {xshape_name}}}, attrs);
reshape_op->Run(scope, place);
}
};
} // namespace operators
} // namespace paddle
......@@ -180,3 +286,8 @@ REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp,
ops::UnsqueezeGradInferShape);
REGISTER_OPERATOR(unsqueeze2, ops::Unsqueeze2Op, ops::Unsqueeze2OpMaker,
ops::Unsqueeze2OpInferShape, ops::Unsqueeze2GradOpMaker);
REGISTER_OPERATOR(unsqueeze2_grad, ops::Unsqueeze2GradOp,
ops::Unsqueeze2GradInferShape);
......@@ -4025,10 +4025,12 @@ def transpose(x, perm, name=None):
helper = LayerHelper('transpose', **locals())
out = helper.create_tmp_variable(x.dtype)
x_shape = helper.create_tmp_variable(x.dtype)
helper.append_op(
type='transpose',
type='transpose2',
inputs={'X': [x]},
outputs={'Out': [out]},
outputs={'Out': [out],
'XShape': [x_shape]},
attrs={'axis': perm})
return out
......@@ -4520,13 +4522,15 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
"Each dimension size given in shape must not be negtive "
"except one unknown dimension.")
helper = LayerHelper("reshape", **locals())
helper = LayerHelper("reshape2", **locals())
out = helper.create_tmp_variable(dtype=x.dtype)
x_shape = helper.create_tmp_variable(dtype=x.dtype)
helper.append_op(
type="reshape",
type="reshape2",
inputs=inputs,
attrs={"shape": shape},
outputs={"Out": out})
outputs={"Out": out,
"XShape": x_shape})
return helper.append_activation(out)
......@@ -4570,11 +4574,13 @@ def squeeze(input, axes, name=None):
"""
helper = LayerHelper("squeeze", **locals())
out = helper.create_tmp_variable(dtype=input.dtype)
x_shape = helper.create_tmp_variable(dtype=input.dtype)
helper.append_op(
type="squeeze",
type="squeeze2",
inputs={"X": input},
attrs={"axes": axes},
outputs={"Out": out})
outputs={"Out": out,
"XShape": x_shape})
return out
......@@ -4605,11 +4611,13 @@ def unsqueeze(input, axes, name=None):
"""
helper = LayerHelper("unsqueeze", **locals())
out = helper.create_tmp_variable(dtype=input.dtype)
x_shape = helper.create_tmp_variable(dtype=input.dtype)
helper.append_op(
type="unsqueeze",
type="unsqueeze2",
inputs={"X": input},
attrs={"axes": axes},
outputs={"Out": out})
outputs={"Out": out,
"XShape": x_shape})
return out
......@@ -5811,10 +5819,12 @@ def flatten(x, axis=1, name=None):
raise ValueError("The axis should be a int, and in range [0, rank(x)]")
out = helper.create_tmp_variable(x.dtype)
x_shape = helper.create_tmp_variable(x.dtype)
helper.append_op(
type='flatten',
type='flatten2',
inputs={"X": x},
outputs={'Out': out},
outputs={'Out': out,
'XShape': x_shape},
attrs={"axis": axis})
return out
......
......@@ -249,7 +249,7 @@ class OpTest(unittest.TestCase):
outs, _ = self._calc_output(place)
return outs
def _calc_output(self, place, parallel=False):
def _calc_output(self, place, parallel=False, no_check_set=None):
program = Program()
block = program.global_block()
......@@ -273,6 +273,8 @@ class OpTest(unittest.TestCase):
# if not, fill the fetch_list by the user configured outputs in test.
if len(fetch_list) == 0:
for var_name, var in six.iteritems(outputs):
if no_check_set is not None and var_name in no_check_set:
continue
if isinstance(var, list):
for v in var:
fetch_list.append(v)
......@@ -291,11 +293,17 @@ class OpTest(unittest.TestCase):
return_numpy=False)
return outs, fetch_list
def check_output_with_place(self, place, atol, equal_nan=False):
outs, fetch_list = self._calc_output(place)
def check_output_with_place(self,
place,
atol,
no_check_set=None,
equal_nan=False):
outs, fetch_list = self._calc_output(place, no_check_set=no_check_set)
for out_name, out_dup in Operator.get_op_outputs(self.op_type):
if out_name not in self.outputs:
continue
if no_check_set is not None and out_name in no_check_set:
continue
def find_actual(target_name, fetch_list):
found = [
......@@ -360,10 +368,10 @@ class OpTest(unittest.TestCase):
places.append(core.CUDAPlace(0))
return places
def check_output(self, atol=1e-5, equal_nan=False):
def check_output(self, atol=1e-5, no_check_set=None, equal_nan=False):
places = self._get_places()
for place in places:
self.check_output_with_place(place, atol, equal_nan)
self.check_output_with_place(place, atol, no_check_set, equal_nan)
def check_output_customized(self, checker):
places = self._get_places()
......
......@@ -22,14 +22,17 @@ from op_test import OpTest
class TestFlattenOp(OpTest):
def setUp(self):
self.op_type = "flatten"
self.op_type = "flatten2"
self.init_test_case()
self.inputs = {"X": np.random.random(self.in_shape).astype("float32")}
self.init_attrs()
self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)}
self.outputs = {
"Out": self.inputs["X"].reshape(self.new_shape),
"XShape": np.random.random(self.in_shape).astype("float32")
}
def test_check_output(self):
self.check_output()
self.check_output(no_check_set=["XShape"])
def test_check_grad(self):
self.check_grad(["X"], "Out")
......
......@@ -22,106 +22,39 @@ from op_test import OpTest
class TestReshapeOp(OpTest):
def setUp(self):
ori_shape = (2, 25)
new_shape = (5, 10)
self.op_type = "reshape"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"shape": new_shape}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
class TestReshapeOpDimInfer1(OpTest):
def setUp(self):
ori_shape = (5, 10)
new_shape = (5, -1, 5)
self.op_type = "reshape"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"shape": new_shape}
self.outputs = {"Out": self.inputs["X"].reshape(self.attrs["shape"])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
class TestReshapeOpDimInfer2(OpTest):
def setUp(self):
ori_shape = (2, 2, 6)
new_shape = (2, 0, 3, -1)
infered_shape = (2, 2, 3, -1)
self.op_type = "reshape"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"shape": new_shape}
self.outputs = {"Out": self.inputs["X"].reshape(infered_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
class TestReshapeOpInplace(OpTest):
def setUp(self):
ori_shape = (2, 25)
new_shape = (5, 10)
self.op_type = "reshape"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"shape": new_shape}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
class TestReshapeOpDimInferInplace1(OpTest):
def setUp(self):
ori_shape = (5, 10)
new_shape = (5, -1, 5)
self.init_data()
self.op_type = "reshape2"
self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
self.attrs = {"shape": self.new_shape}
self.outputs = {
"Out": self.inputs["X"].reshape(self.infered_shape),
'XShape': np.random.random(self.ori_shape).astype("float32")
}
self.op_type = "reshape"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"shape": new_shape}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def init_data(self):
self.ori_shape = (2, 25)
self.new_shape = (5, 10)
self.infered_shape = (5, 10)
def test_check_output(self):
self.check_output()
self.check_output(no_check_set=['XShape'])
def test_check_grad(self):
self.check_grad(["X"], "Out")
class TestReshapeOpDimInferInplace2(OpTest):
def setUp(self):
ori_shape = (2, 2, 6)
new_shape = (2, 0, 3, -1)
infered_shape = (2, 2, 3, -1)
self.op_type = "reshape"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"shape": new_shape}
self.outputs = {"Out": self.inputs["X"].reshape(infered_shape)}
class TestReshapeOpDimInfer1(TestReshapeOp):
def init_data(self):
self.ori_shape = (5, 10)
self.new_shape = (5, -1, 5)
self.infered_shape = (5, -1, 5)
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
class TestReshapeOpDimInfer2(TestReshapeOp):
def init_data(self):
self.ori_shape = (2, 2, 6)
self.new_shape = (2, 0, 3, -1)
self.infered_shape = (2, 2, 3, -1)
class TestReshapeOpWithInputShape(OpTest):
......@@ -130,20 +63,23 @@ class TestReshapeOpWithInputShape(OpTest):
new_shape = (0, -1, 5)
actual_shape = (2, 3, 5)
self.op_type = "reshape"
self.op_type = "reshape2"
self.inputs = {
"X": np.random.random(ori_shape).astype("float32"),
"Shape": np.array(
actual_shape, dtype="int32")
}
self.attrs = {"shape": new_shape}
self.outputs = {"Out": self.inputs["X"].reshape(actual_shape)}
self.outputs = {
"Out": self.inputs["X"].reshape(actual_shape),
'XShape': np.random.random(ori_shape).astype("float32")
}
def test_check_output(self):
self.check_output()
self.check_output(no_check_set=['XShape'])
def test_check_grad(self):
self.check_grad(["X"], "Out")
self.check_grad(["X"], "Out", sum_outputs=["Out"])
if __name__ == "__main__":
......
......@@ -23,14 +23,17 @@ from op_test import OpTest
# Correct: General.
class TestSqueezeOp(OpTest):
def setUp(self):
self.op_type = "squeeze"
self.op_type = "squeeze2"
self.init_test_case()
self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
self.init_attrs()
self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)}
self.outputs = {
"Out": self.inputs["X"].reshape(self.new_shape),
"XShape": np.random.random(self.ori_shape).astype("float32")
}
def test_check_output(self):
self.check_output()
self.check_output(no_check_set=['XShape'])
def test_check_grad(self):
self.check_grad(["X"], "Out")
......
......@@ -22,16 +22,19 @@ from op_test import OpTest
class TestTransposeOp(OpTest):
def setUp(self):
self.initTestCase()
self.op_type = "transpose"
self.op_type = "transpose2"
self.inputs = {'X': np.random.random(self.shape).astype("float32")}
self.attrs = {'axis': list(self.axis)}
self.outputs = {'Out': self.inputs['X'].transpose(self.axis)}
self.outputs = {
'XShape': np.random.random(self.shape).astype("float32"),
'Out': self.inputs['X'].transpose(self.axis)
}
def test_check_output(self):
self.check_output()
self.check_output(no_check_set=['XShape'])
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', sum_outputs=['Out'])
def initTestCase(self):
self.shape = (3, 4)
......
......@@ -24,13 +24,16 @@ from op_test import OpTest
class TestUnsqueezeOp(OpTest):
def setUp(self):
self.init_test_case()
self.op_type = "unsqueeze"
self.op_type = "unsqueeze2"
self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
self.init_attrs()
self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)}
self.outputs = {
"Out": self.inputs["X"].reshape(self.new_shape),
"XShape": np.random.random(self.ori_shape).astype("float32")
}
def test_check_output(self):
self.check_output()
self.check_output(no_check_set=["XShape"])
def test_check_grad(self):
self.check_grad(["X"], "Out")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册