Unverified commit 88685255, authored by: C chengduo, committed by: GitHub

Refine reshape_grad and transpose_grad (#13074)

* Add intermediate

* fix flatten/squeeze/unsqueeze

* Considering compatibility issues, we could not modify the original ops

* follow comment

* reset the shape of XShape
Parent: 7dd8adb5
@@ -157,6 +157,116 @@ class FlattenGradOp : public framework::OperatorBase {
}
};
// FIXME(zcd): flatten2 adds an intermediate output (XShape) on top of flatten.
// XShape carries the shape and lod of X, which are needed in flatten2_grad,
// so the framework can reuse the memory of X as soon as flatten2_op finishes.
// Considering compatibility issues, we could not modify the original
// flatten_op.
class Flatten2OpInferShape : public FlattenOpInferShape {
public:
void operator()(framework::InferShapeContext *ctx) const override {
FlattenOpInferShape::operator()(ctx);
PADDLE_ENFORCE(ctx->HasOutput("XShape"),
"Output (XShape) of Flatten op should not be null.");
const auto &in_dims = ctx->GetInputDim("X");
std::vector<int64_t> xshape_dims(in_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < in_dims.size(); ++i) {
xshape_dims[i + 1] = in_dims[i];
}
ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims));
ctx->ShareLoD("X", "XShape");
}
};
class Flatten2Op : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto &axis = Attr<int>("axis");
auto in_dims =
scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
const auto &out_dims = FlattenOpInferShape::GetOutputShape(axis, in_dims);
framework::AttributeMap attrs;
attrs["shape"] = out_dims;
attrs["inplace"] = false;
// Invoke Reshape Op
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape2", {{"X", {Input("X")}}, {"Shape", {}}},
{{"Out", {Output("Out")}}, {"XShape", {Output("XShape")}}}, attrs);
reshape_op->Run(scope, place);
}
};
class Flatten2OpMaker : public FlattenOpMaker {
public:
void Make() override {
FlattenOpMaker::Make();
AddOutput("XShape",
"XShape is just used to store the shape and lod of X, which will "
"be used in FlattenGradOp.")
.AsIntermediate();
}
};
class Flatten2GradOpMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
std::unique_ptr<framework::OpDesc> Apply() const override {
auto *grad_op = new framework::OpDesc();
grad_op->SetType("flatten2_grad");
grad_op->SetInput("XShape", Output("XShape"));
grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
grad_op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDesc>(grad_op);
}
};
class Flatten2GradInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("XShape"),
"Input(XShape) shouldn't be null.");
PADDLE_ENFORCE(context->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) shouldn't be null.");
auto xshape_dims = context->GetInputDim("XShape");
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
context->SetOutputDim(framework::GradVarName("X"), x_dims);
context->ShareLoD("XShape", framework::GradVarName("X"));
}
};
class Flatten2GradOp : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto dx_name = Output(framework::GradVarName("X"));
auto dout_name = Input(framework::GradVarName("Out"));
auto xshape_name = Input("XShape");
auto xshape_dims =
scope.FindVar(xshape_name)->Get<framework::LoDTensor>().dims();
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(x_dims);
attrs["inplace"] = false;
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape2", {{"X", {dout_name}}, {"Shape", {}}},
{{"Out", {dx_name}}, {"XShape", {xshape_name}}}, attrs);
reshape_op->Run(scope, place);
}
};
} // namespace operators
} // namespace paddle
@@ -167,3 +277,8 @@ REGISTER_OPERATOR(flatten, ops::FlattenOp, ops::FlattenOpMaker,
ops::FlattenOpInferShape,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(flatten_grad, ops::FlattenGradOp, ops::FlattenGradInferShape);
REGISTER_OPERATOR(flatten2, ops::Flatten2Op, ops::Flatten2OpMaker,
ops::Flatten2OpInferShape, ops::Flatten2GradOpMaker);
REGISTER_OPERATOR(flatten2_grad, ops::Flatten2GradOp,
ops::Flatten2GradInferShape);
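A note on the XShape convention used throughout this change: the forward *2 op emits a shape-only tensor whose dims are the input's dims prefixed with a 0, and the grad op recovers the input dims by slicing that leading 0 off, so X's data never has to be kept alive for the backward pass. A minimal numpy sketch of that convention (the helper names below are illustrative, not Paddle API):

import numpy as np

def make_xshape_dims(x_dims):
    # XShape carries [0, d0, d1, ...]; only the dims (and LoD) matter,
    # no data from X is copied into it.
    return [0] + list(x_dims)

def recover_x_dims(xshape_dims):
    # The grad op slices off the leading 0 to get the original dims back,
    # mirroring the slice_ddim(xshape_dims, 1, xshape_dims.size()) calls above.
    return list(xshape_dims)[1:]

x = np.random.rand(3, 4, 5).astype("float32")
xshape_dims = make_xshape_dims(x.shape)          # [0, 3, 4, 5]
assert recover_x_dims(xshape_dims) == [3, 4, 5]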
@@ -246,6 +246,88 @@ class ReshapeGradKernel {
}
};
// FIXME(zcd): reshape2 adds an intermediate output (XShape) on top of reshape.
// XShape carries the shape and lod of X, which are needed in reshape2_grad,
// so the framework can reuse the memory of X as soon as reshape2_op finishes.
// Considering compatibility issues, we could not modify the original
// reshape_op.
class Reshape2Op : public ReshapeOp {
public:
Reshape2Op(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: ReshapeOp(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
ReshapeOp::InferShape(ctx);
PADDLE_ENFORCE(ctx->HasOutput("XShape"),
"Output(XShape) of ReshapeOp should not be null.");
const auto &x_dims = ctx->GetInputDim("X");
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims));
ctx->ShareLoD("X", /*->*/ "XShape");
}
};
class Reshape2OpMaker : public ReshapeOpMaker {
public:
void Make() override {
ReshapeOpMaker::Make();
AddOutput("XShape",
"XShape is just used to store the shape and lod of X, which will "
"be used in FlattenGradOp.")
.AsIntermediate();
}
};
class Reshape2GradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
std::unique_ptr<framework::OpDesc> Apply() const override {
auto *grad_op = new framework::OpDesc();
grad_op->SetType("reshape2_grad");
grad_op->SetInput("XShape", Output("XShape"));
grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
grad_op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDesc>(grad_op);
}
};
class Reshape2GradOp : public framework::OperatorWithKernel {
public:
Reshape2GradOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("XShape"), "Input(XShape) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) shouldn't be null.");
auto xshape_dims = ctx->GetInputDim("XShape");
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
ctx->ShareLoD("XShape", framework::GradVarName("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))
->type()),
ctx.device_context());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
@@ -261,6 +343,17 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel);
REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker,
ops::Reshape2GradMaker);
REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel,
int64_t, ops::ReshapeKernel);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
double, ops::ReshapeGradKernel, int,
ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel);
#ifdef PADDLE_WITH_CUDA
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel,
@@ -269,4 +362,11 @@ REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
double, ops::ReshapeGradKernel, int,
ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel);
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel,
int64_t, ops::ReshapeKernel);
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
double, ops::ReshapeGradKernel, int,
ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel);
#endif
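The reason XShape alone is enough for reshape2_grad: the gradient of a reshape is simply the incoming gradient reshaped back to the input's dims, so the backward pass needs X's shape (and LoD) but none of its values. A small numpy sketch of the forward/backward pair under that assumption (function names are illustrative):

import numpy as np

def reshape2_forward(x, new_shape):
    out = x.reshape(new_shape)
    xshape_dims = [0] + list(x.shape)   # shape-only carrier for the grad op
    return out, xshape_dims

def reshape2_backward(dout, xshape_dims):
    x_dims = xshape_dims[1:]            # drop the leading 0
    return dout.reshape(x_dims)

x = np.arange(12.0).reshape(3, 4)
out, xshape_dims = reshape2_forward(x, (2, 6))
dx = reshape2_backward(np.ones_like(out), xshape_dims)
assert dx.shape == x.shape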
@@ -181,6 +181,113 @@ class SqueezeGradOp : public framework::OperatorBase {
}
};
// FIXME(zcd): squeeze2 adds an intermediate output (XShape) on top of squeeze.
// XShape carries the shape and lod of X, which are needed in squeeze2_grad,
// so the framework can reuse the memory of X as soon as squeeze2_op finishes.
// Considering compatibility issues, we could not modify the original
// squeeze_op.
class Squeeze2OpMaker : public SqueezeOpMaker {
public:
void Make() override {
SqueezeOpMaker::Make();
AddOutput("XShape",
"XShape is just used to store the shape and lod of X, which will "
"be used in SqueezeGradOp.")
.AsIntermediate();
}
};
class Squeeze2OpInferShape : public SqueezeOpInferShape {
public:
void operator()(framework::InferShapeContext *ctx) const override {
SqueezeOpInferShape::operator()(ctx);
PADDLE_ENFORCE(ctx->HasOutput("XShape"),
"Output(XShape) of Squeeze operator should not be null.");
const auto &x_dims = ctx->GetInputDim("X");
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims));
ctx->ShareLoD("X", /*->*/ "XShape");
}
};
class Squeeze2Op : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto &axes = Attr<std::vector<int>>("axes");
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
auto out_dims = Squeeze2OpInferShape::GetOutputShape(axes, x_dims);
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(out_dims);
// Invoke Reshape Op
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape2", {{"X", {Input("X")}}, {"Shape", {}}},
{{"Out", {Output("Out")}}, {"XShape", {Output("XShape")}}}, attrs);
reshape_op->Run(scope, place);
}
};
class Squeeze2GradOpMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
std::unique_ptr<framework::OpDesc> Apply() const override {
auto *grad_op = new framework::OpDesc();
grad_op->SetType("squeeze2_grad");
grad_op->SetInput("XShape", Output("XShape"));
grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
grad_op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDesc>(grad_op);
}
};
class Squeeze2GradInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("XShape"),
"Input(XShape) shouldn't be null.");
PADDLE_ENFORCE(context->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) shouldn't be null.");
auto xshape_dims = context->GetInputDim("XShape");
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
context->SetOutputDim(framework::GradVarName("X"), x_dims);
context->ShareLoD("XShape", framework::GradVarName("X"));
}
};
class Squeeze2GradOp : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto dx_name = Output(framework::GradVarName("X"));
auto dout_name = Input(framework::GradVarName("Out"));
auto xshape_name = Input("XShape");
auto xshape_dims =
scope.FindVar(xshape_name)->Get<framework::LoDTensor>().dims();
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(x_dims);
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape2", {{"X", {dout_name}}, {"Shape", {}}},
{{"Out", {dx_name}}, {"XShape", {xshape_name}}}, attrs);
reshape_op->Run(scope, place);
}
};
} // namespace operators
} // namespace paddle
@@ -192,3 +299,8 @@ REGISTER_OPERATOR(squeeze, ops::SqueezeOp, ops::SqueezeOpMaker,
ops::SqueezeOpInferShape,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(squeeze_grad, ops::SqueezeGradOp, ops::SqueezeGradInferShape);
REGISTER_OPERATOR(squeeze2, ops::Squeeze2Op, ops::Squeeze2OpMaker,
ops::Squeeze2OpInferShape, ops::Squeeze2GradOpMaker);
REGISTER_OPERATOR(squeeze2_grad, ops::Squeeze2GradOp,
ops::Squeeze2GradInferShape);
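Since Squeeze2Op and Squeeze2GradOp both delegate to reshape2 at runtime, the only squeeze-specific work is computing the squeezed target shape from the axes attribute. A rough numpy-based sketch of that shape rule (illustrative only, not the GetOutputShape implementation):

import numpy as np

def squeezed_shape(x_dims, axes):
    # With explicit axes, drop those size-1 dimensions; with no axes,
    # drop every dimension of size 1 -- the same convention as np.squeeze.
    if axes:
        return [d for i, d in enumerate(x_dims) if i not in set(axes)]
    return [d for d in x_dims if d != 1]

x = np.random.rand(1, 3, 1, 5)
assert squeezed_shape(x.shape, [0, 2]) == [3, 5]
assert squeezed_shape(x.shape, []) == [3, 5]
assert list(np.squeeze(x, axis=(0, 2)).shape) == [3, 5]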
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/transpose_op.h"
#include <string>
#include <vector>
namespace paddle {
@@ -24,7 +25,7 @@ class TransposeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null");
auto x_dims = ctx->GetInputDim("X");
@@ -101,7 +102,7 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
@@ -113,6 +114,93 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
}
};
// FIXME(zcd): transpose2 adds an intermediate output (XShape) on top of
// transpose. XShape carries the shape and lod of X, which are needed in
// transpose2_grad, so the framework can reuse the memory of X as soon as
// transpose2_op finishes.
// Considering compatibility issues, we could not modify the original
// transpose_op.
class Transpose2Op : public TransposeOp {
public:
Transpose2Op(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: TransposeOp(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
TransposeOp::InferShape(ctx);
PADDLE_ENFORCE(ctx->HasOutput("XShape"),
"Output(XShape) should not be null");
const auto &in_dims = ctx->GetInputDim("X");
std::vector<int64_t> x_shape_dim(in_dims.size() + 1);
x_shape_dim[0] = 0;
for (int i = 0; i < in_dims.size(); ++i) {
x_shape_dim[i + 1] = in_dims[i];
}
ctx->SetOutputDim("XShape", framework::make_ddim(x_shape_dim));
ctx->ShareLoD("X", /*->*/ "XShape");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
}
};
class Transpose2OpMaker : public TransposeOpMaker {
public:
void Make() override {
TransposeOpMaker::Make();
AddOutput("XShape", "(Tensor)The output tensor.").AsIntermediate();
}
};
class Transpose2GradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
std::unique_ptr<framework::OpDesc> Apply() const override {
auto *grad_op = new framework::OpDesc();
grad_op->SetType("transpose2_grad");
grad_op->SetInput("XShape", Output("XShape"));
grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
grad_op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDesc>(grad_op);
}
};
class Transpose2OpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("XShape"), "Input(XShape) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
if (ctx->HasOutput(framework::GradVarName("X"))) {
auto xshape_dim = ctx->GetInputDim("XShape");
auto x_shape_dim =
framework::slice_ddim(xshape_dim, 1, xshape_dim.size());
ctx->SetOutputDim(framework::GradVarName("X"), x_shape_dim);
ctx->ShareLoD("XShape", framework::GradVarName("X"));
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))
->type()),
ctx.device_context());
}
};
} // namespace operators
} // namespace paddle
@@ -120,8 +208,20 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(transpose, ops::TransposeOp, ops::TransposeOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(transpose_grad, ops::TransposeOpGrad);
REGISTER_OP_CPU_KERNEL(
transpose, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
transpose_grad,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OPERATOR(transpose2, ops::Transpose2Op, ops::Transpose2OpMaker,
ops::Transpose2GradMaker);
REGISTER_OPERATOR(transpose2_grad, ops::Transpose2OpGrad);
REGISTER_OP_CPU_KERNEL(
transpose2,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
transpose2_grad,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>);
@@ -21,3 +21,10 @@ REGISTER_OP_CUDA_KERNEL(
REGISTER_OP_CUDA_KERNEL(
transpose_grad,
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_CUDA_KERNEL(
transpose2,
ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_CUDA_KERNEL(
transpose2_grad,
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>);
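For transpose2 the gradient is the incoming gradient transposed by the inverse permutation, and XShape only supplies the original dims and LoD for the gradient variable. A numpy sketch of that math (the function name is illustrative; the kernels above are the reused transpose kernels):

import numpy as np

def transpose2_backward(dout, axis):
    # The inverse permutation satisfies inverse[axis[i]] = i;
    # np.argsort(axis) produces exactly that.
    reversed_axis = np.argsort(axis)
    return dout.transpose(reversed_axis)

x = np.random.rand(2, 3, 4)
axis = (1, 2, 0)
out = x.transpose(axis)
dx = transpose2_backward(np.ones_like(out), axis)
assert dx.shape == x.shape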
@@ -168,6 +168,112 @@ class UnsqueezeGradOp : public framework::OperatorBase {
}
};
// FIXME(zcd): unsqueeze2 adds an intermediate output (XShape) on top of
// unsqueeze. XShape carries the shape and lod of X, which are needed in
// unsqueeze2_grad, so the framework can reuse the memory of X as soon as
// unsqueeze2_op finishes.
// Considering compatibility issues, we could not modify the original
// unsqueeze_op.
class Unsqueeze2OpInferShape : public UnsqueezeOpInferShape {
public:
void operator()(framework::InferShapeContext *ctx) const override {
UnsqueezeOpInferShape::operator()(ctx);
PADDLE_ENFORCE(ctx->HasOutput("XShape"),
"Output(XShape) of Unsqueeze operator should not be null.");
const auto &x_dims = ctx->GetInputDim("X");
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims));
ctx->ShareLoD("X", /*->*/ "XShape");
}
};
class Unsqueeze2OpMaker : public UnsqueezeOpMaker {
public:
void Make() override {
UnsqueezeOpMaker::Make();
AddOutput("XShape",
"XShape is just used to store the shape and lod of X, which will "
"be used in UnsqueezeGradOp.")
.AsIntermediate();
}
};
class Unsqueeze2Op : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto &axes = Attr<std::vector<int>>("axes");
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
auto out_dims = Unsqueeze2OpInferShape::GetOutputShape(axes, x_dims);
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(out_dims);
// Invoke Reshape op.
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape2", {{"X", {Input("X")}}, {"Shape", {}}},
{{"Out", {Output("Out")}}, {"XShape", {Output("XShape")}}}, attrs);
reshape_op->Run(scope, place);
}
};
class Unsqueeze2GradOpMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
std::unique_ptr<framework::OpDesc> Apply() const override {
auto *grad_op = new framework::OpDesc();
grad_op->SetType("unsqueeze2_grad");
grad_op->SetInput("XShape", Output("XShape"));
grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
grad_op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDesc>(grad_op);
}
};
class Unsqueeze2GradInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("XShape"),
"Input(XShape) shouldn't be null.");
PADDLE_ENFORCE(context->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) shouldn't be null.");
auto xshape_dims = context->GetInputDim("XShape");
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
context->SetOutputDim(framework::GradVarName("X"), x_dims);
context->ShareLoD("XShape", framework::GradVarName("X"));
}
};
class Unsqueeze2GradOp : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto dx_name = Output(framework::GradVarName("X"));
auto dout_name = Input(framework::GradVarName("Out"));
auto xshape_name = Input("XShape");
auto xshape_dims =
scope.FindVar(xshape_name)->Get<framework::LoDTensor>().dims();
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(x_dims);
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape2", {{"X", {dout_name}}, {"Shape", {}}},
{{"Out", {dx_name}}, {"XShape", {xshape_name}}}, attrs);
reshape_op->Run(scope, place);
}
};
} // namespace operators
} // namespace paddle
@@ -180,3 +286,8 @@ REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp,
ops::UnsqueezeGradInferShape);
REGISTER_OPERATOR(unsqueeze2, ops::Unsqueeze2Op, ops::Unsqueeze2OpMaker,
ops::Unsqueeze2OpInferShape, ops::Unsqueeze2GradOpMaker);
REGISTER_OPERATOR(unsqueeze2_grad, ops::Unsqueeze2GradOp,
ops::Unsqueeze2GradInferShape);
@@ -4025,10 +4025,12 @@ def transpose(x, perm, name=None):
    helper = LayerHelper('transpose', **locals())
    out = helper.create_tmp_variable(x.dtype)
    x_shape = helper.create_tmp_variable(x.dtype)
    helper.append_op(
        type='transpose2',
        inputs={'X': [x]},
        outputs={'Out': [out],
                 'XShape': [x_shape]},
        attrs={'axis': perm})
    return out
@@ -4520,13 +4522,15 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
        "Each dimension size given in shape must not be negtive "
        "except one unknown dimension.")
    helper = LayerHelper("reshape2", **locals())
    out = helper.create_tmp_variable(dtype=x.dtype)
    x_shape = helper.create_tmp_variable(dtype=x.dtype)
    helper.append_op(
        type="reshape2",
        inputs=inputs,
        attrs={"shape": shape},
        outputs={"Out": out,
                 "XShape": x_shape})
    return helper.append_activation(out)
@@ -4570,11 +4574,13 @@ def squeeze(input, axes, name=None):
    """
    helper = LayerHelper("squeeze", **locals())
    out = helper.create_tmp_variable(dtype=input.dtype)
    x_shape = helper.create_tmp_variable(dtype=input.dtype)
    helper.append_op(
        type="squeeze2",
        inputs={"X": input},
        attrs={"axes": axes},
        outputs={"Out": out,
                 "XShape": x_shape})
    return out
@@ -4605,11 +4611,13 @@ def unsqueeze(input, axes, name=None):
    """
    helper = LayerHelper("unsqueeze", **locals())
    out = helper.create_tmp_variable(dtype=input.dtype)
    x_shape = helper.create_tmp_variable(dtype=input.dtype)
    helper.append_op(
        type="unsqueeze2",
        inputs={"X": input},
        attrs={"axes": axes},
        outputs={"Out": out,
                 "XShape": x_shape})
    return out
@@ -5811,10 +5819,12 @@ def flatten(x, axis=1, name=None):
        raise ValueError("The axis should be a int, and in range [0, rank(x)]")
    out = helper.create_tmp_variable(x.dtype)
    x_shape = helper.create_tmp_variable(x.dtype)
    helper.append_op(
        type='flatten2',
        inputs={"X": x},
        outputs={'Out': out,
                 'XShape': x_shape},
        attrs={"axis": axis})
    return out
......
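On the Python side the layer signatures and return values are unchanged; each helper now appends the corresponding *2 op and creates the XShape variable internally, so existing user code keeps working. A minimal usage sketch, assuming the fluid.layers.data helper and its default batch dimension from this release:

import paddle.fluid as fluid

# These calls now emit transpose2 / reshape2 / flatten2 ops under the hood,
# but each still returns only the Out variable; XShape stays internal.
x = fluid.layers.data(name='x', shape=[3, 4, 5], dtype='float32')  # (-1, 3, 4, 5)
y = fluid.layers.transpose(x, perm=[0, 2, 1, 3])                   # (-1, 4, 3, 5)
z = fluid.layers.reshape(y, shape=[-1, 60])                        # (-1, 60)
f = fluid.layers.flatten(z, axis=1)                                # (-1, 60)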
@@ -249,7 +249,7 @@ class OpTest(unittest.TestCase):
        outs, _ = self._calc_output(place)
        return outs

    def _calc_output(self, place, parallel=False, no_check_set=None):
        program = Program()
        block = program.global_block()
@@ -273,6 +273,8 @@ class OpTest(unittest.TestCase):
        # if not, fill the fetch_list by the user configured outputs in test.
        if len(fetch_list) == 0:
            for var_name, var in six.iteritems(outputs):
                if no_check_set is not None and var_name in no_check_set:
                    continue
                if isinstance(var, list):
                    for v in var:
                        fetch_list.append(v)
@@ -291,11 +293,17 @@ class OpTest(unittest.TestCase):
                            return_numpy=False)
        return outs, fetch_list

    def check_output_with_place(self,
                                place,
                                atol,
                                no_check_set=None,
                                equal_nan=False):
        outs, fetch_list = self._calc_output(place, no_check_set=no_check_set)
        for out_name, out_dup in Operator.get_op_outputs(self.op_type):
            if out_name not in self.outputs:
                continue
            if no_check_set is not None and out_name in no_check_set:
                continue

        def find_actual(target_name, fetch_list):
            found = [
@@ -360,10 +368,10 @@ class OpTest(unittest.TestCase):
            places.append(core.CUDAPlace(0))
        return places

    def check_output(self, atol=1e-5, no_check_set=None, equal_nan=False):
        places = self._get_places()
        for place in places:
            self.check_output_with_place(place, atol, no_check_set, equal_nan)

    def check_output_customized(self, checker):
        places = self._get_places()
......
@@ -22,14 +22,17 @@ from op_test import OpTest
class TestFlattenOp(OpTest):
    def setUp(self):
        self.op_type = "flatten2"
        self.init_test_case()
        self.inputs = {"X": np.random.random(self.in_shape).astype("float32")}
        self.init_attrs()
        self.outputs = {
            "Out": self.inputs["X"].reshape(self.new_shape),
            "XShape": np.random.random(self.in_shape).astype("float32")
        }

    def test_check_output(self):
        self.check_output(no_check_set=["XShape"])

    def test_check_grad(self):
        self.check_grad(["X"], "Out")
......
@@ -22,106 +22,39 @@ from op_test import OpTest
class TestReshapeOp(OpTest):
    def setUp(self):
        self.init_data()
        self.op_type = "reshape2"
        self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
        self.attrs = {"shape": self.new_shape}
        self.outputs = {
            "Out": self.inputs["X"].reshape(self.infered_shape),
            'XShape': np.random.random(self.ori_shape).astype("float32")
        }

    def init_data(self):
        self.ori_shape = (2, 25)
        self.new_shape = (5, 10)
        self.infered_shape = (5, 10)

    def test_check_output(self):
        self.check_output(no_check_set=['XShape'])

    def test_check_grad(self):
        self.check_grad(["X"], "Out")


class TestReshapeOpDimInfer1(TestReshapeOp):
    def init_data(self):
        self.ori_shape = (5, 10)
        self.new_shape = (5, -1, 5)
        self.infered_shape = (5, -1, 5)


class TestReshapeOpDimInfer2(TestReshapeOp):
    def init_data(self):
        self.ori_shape = (2, 2, 6)
        self.new_shape = (2, 0, 3, -1)
        self.infered_shape = (2, 2, 3, -1)


class TestReshapeOpWithInputShape(OpTest):
@@ -130,20 +63,23 @@ class TestReshapeOpWithInputShape(OpTest):
        new_shape = (0, -1, 5)
        actual_shape = (2, 3, 5)

        self.op_type = "reshape2"
        self.inputs = {
            "X": np.random.random(ori_shape).astype("float32"),
            "Shape": np.array(
                actual_shape, dtype="int32")
        }
        self.attrs = {"shape": new_shape}
        self.outputs = {
            "Out": self.inputs["X"].reshape(actual_shape),
            'XShape': np.random.random(ori_shape).astype("float32")
        }

    def test_check_output(self):
        self.check_output(no_check_set=['XShape'])

    def test_check_grad(self):
        self.check_grad(["X"], "Out", sum_outputs=["Out"])


if __name__ == "__main__":
......
@@ -23,14 +23,17 @@ from op_test import OpTest
# Correct: General.
class TestSqueezeOp(OpTest):
    def setUp(self):
        self.op_type = "squeeze2"
        self.init_test_case()
        self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
        self.init_attrs()
        self.outputs = {
            "Out": self.inputs["X"].reshape(self.new_shape),
            "XShape": np.random.random(self.ori_shape).astype("float32")
        }

    def test_check_output(self):
        self.check_output(no_check_set=['XShape'])

    def test_check_grad(self):
        self.check_grad(["X"], "Out")
......
@@ -22,16 +22,19 @@ from op_test import OpTest
class TestTransposeOp(OpTest):
    def setUp(self):
        self.initTestCase()
        self.op_type = "transpose2"
        self.inputs = {'X': np.random.random(self.shape).astype("float32")}
        self.attrs = {'axis': list(self.axis)}
        self.outputs = {
            'XShape': np.random.random(self.shape).astype("float32"),
            'Out': self.inputs['X'].transpose(self.axis)
        }

    def test_check_output(self):
        self.check_output(no_check_set=['XShape'])

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', sum_outputs=['Out'])

    def initTestCase(self):
        self.shape = (3, 4)
......
@@ -24,13 +24,16 @@ from op_test import OpTest
class TestUnsqueezeOp(OpTest):
    def setUp(self):
        self.init_test_case()
        self.op_type = "unsqueeze2"
        self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
        self.init_attrs()
        self.outputs = {
            "Out": self.inputs["X"].reshape(self.new_shape),
            "XShape": np.random.random(self.ori_shape).astype("float32")
        }

    def test_check_output(self):
        self.check_output(no_check_set=["XShape"])

    def test_check_grad(self):
        self.check_grad(["X"], "Out")
......