Commit e2e82bde, authored by minqiyang

Accelerate Reshape op

Parent: e1904ac2
paddle/fluid/operators/reshape_op.cc
@@ -164,7 +164,7 @@ dimension value will be copied from Input(X) at runtime. Note that the index of
 [2, 3, 4], Attr(shape) = [2, 3, 2, 0] is an invalid input.
 3. Input(Shape) has a higher priority than Attr(shape) if it is provided, while
 Attr(shape) still should be set correctly to gurantee shape inference in
 compile-time.
 )DOC");
@@ -195,6 +195,7 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
   }
 };
 
+template <typename T>
 class ReshapeKernel {
  public:
   void operator()(const framework::ExecutionContext &ctx) const {
@@ -227,12 +228,15 @@ class ReshapeKernel {
                       "sequence_reshape op.");
     }
 
-    out->mutable_data(ctx.GetPlace(), in->type());
-    framework::TensorCopySync(*in, ctx.GetPlace(), out);
+    if (in->data<T>() !=
+        reinterpret_cast<T *>(out->mutable_data(ctx.GetPlace(), in->type()))) {
+      framework::TensorCopySync(*in, ctx.GetPlace(), out);
+    }
     out->Resize(out_dims);
   }
 };
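This guard is the heart of the speedup. Reshape does not change element data, so when the output is allowed to reuse the input's allocation, mutable_data can hand back the very pointer the input already holds, and the previously unconditional TensorCopySync was pure overhead. A minimal standalone sketch of the pattern (plain C++ with hypothetical names, not Paddle's API):

    #include <cstring>
    #include <iostream>
    #include <vector>

    // Copy n elements from in to out, skipping the copy entirely when both
    // pointers alias the same buffer -- the case an in-place reshape hits.
    template <typename T>
    void CopyUnlessAliased(const T *in, std::size_t n, T *out) {
      if (in != out) {
        std::memcpy(out, in, n * sizeof(T));
      }
    }

    int main() {
      std::vector<float> buf{1, 2, 3, 4, 5, 6};
      CopyUnlessAliased(buf.data(), buf.size(), buf.data());  // aliased: no-op
      std::vector<float> dst(buf.size());
      CopyUnlessAliased(buf.data(), buf.size(), dst.data());  // distinct: copies
      std::cout << dst.back() << "\n";  // prints 6
      return 0;
    }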
+template <typename T>
 class ReshapeGradKernel {
  public:
   void operator()(const framework::ExecutionContext &ctx) const {
@@ -240,8 +244,9 @@ class ReshapeGradKernel {
     auto *d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
     auto in_dims = d_x->dims();
 
-    d_x->mutable_data(ctx.GetPlace(), d_out->type());
-    framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
+    if (d_out->data<T>() != d_x->mutable_data(ctx.GetPlace(), d_out->type())) {
+      framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
+    }
     d_x->Resize(in_dims);
   }
 };
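The gradient kernel gets the identical guard for d_out and d_x. A plausible reading of why both classes became templates on T: the typed accessor in->data<T>() is what yields a comparable T* for the raw-pointer check, which the previous untyped functor could not express. The cost is that every supported element type must now be spelled out explicitly at registration, as the kernel lists further below show.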
@@ -259,7 +264,6 @@ class Reshape2Op : public ReshapeOp {
       : ReshapeOp(type, inputs, outputs, attrs) {}
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    ReshapeOp::InferShape(ctx);
     PADDLE_ENFORCE(ctx->HasOutput("XShape"),
                    "Output(XShape) of ReshapeOp should not be null.");
     const auto &x_dims = ctx->GetInputDim("X");
@@ -270,6 +274,8 @@ class Reshape2Op : public ReshapeOp {
     }
     ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims));
     ctx->ShareLoD("X", /*->*/ "XShape");
+
+    ReshapeOp::InferShape(ctx);
   }
 };
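For readability, here is Reshape2Op::InferShape as it reads after both hunks are applied (reconstructed from the diff; the elided middle lines are summarized in a comment). The visible effect of the move is ordering: the base-class inference for Out now runs only after XShape and its LoD have been populated from the incoming X.

    void InferShape(framework::InferShapeContext *ctx) const override {
      PADDLE_ENFORCE(ctx->HasOutput("XShape"),
                     "Output(XShape) of ReshapeOp should not be null.");
      const auto &x_dims = ctx->GetInputDim("X");
      // ... build xshape_dims from x_dims (lines elided in the diff) ...
      ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims));
      ctx->ShareLoD("X", /*->*/ "XShape");

      ReshapeOp::InferShape(ctx);
    }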
@@ -335,38 +341,46 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp);
-REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
-                               ops::ReshapeKernel, int, ops::ReshapeKernel,
-                               int64_t, ops::ReshapeKernel);
-REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
-                               double, ops::ReshapeGradKernel, int,
-                               ops::ReshapeGradKernel, int64_t,
-                               ops::ReshapeGradKernel);
+REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel<float>,
+                               double, ops::ReshapeKernel<double>, int,
+                               ops::ReshapeKernel<int>, int64_t,
+                               ops::ReshapeKernel<int64_t>);
+REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float,
+                               ops::ReshapeGradKernel<float>, double,
+                               ops::ReshapeGradKernel<double>, int,
+                               ops::ReshapeGradKernel<int>, int64_t,
+                               ops::ReshapeGradKernel<int64_t>);
 
 REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker,
                   ops::Reshape2GradMaker);
 REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp);
-REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
-                               ops::ReshapeKernel, int, ops::ReshapeKernel,
-                               int64_t, ops::ReshapeKernel);
-REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
-                               double, ops::ReshapeGradKernel, int,
-                               ops::ReshapeGradKernel, int64_t,
-                               ops::ReshapeGradKernel);
+REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel<float>,
+                               double, ops::ReshapeKernel<double>, int,
+                               ops::ReshapeKernel<int>, int64_t,
+                               ops::ReshapeKernel<int64_t>);
+REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float,
+                               ops::ReshapeGradKernel<float>, double,
+                               ops::ReshapeGradKernel<double>, int,
+                               ops::ReshapeGradKernel<int>, int64_t,
+                               ops::ReshapeGradKernel<int64_t>);
 
 #ifdef PADDLE_WITH_CUDA
-REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
-                                ops::ReshapeKernel, int, ops::ReshapeKernel,
-                                int64_t, ops::ReshapeKernel);
-REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
-                                double, ops::ReshapeGradKernel, int,
-                                ops::ReshapeGradKernel, int64_t,
-                                ops::ReshapeGradKernel);
-REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
-                                ops::ReshapeKernel, int, ops::ReshapeKernel,
-                                int64_t, ops::ReshapeKernel);
-REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
-                                double, ops::ReshapeGradKernel, int,
-                                ops::ReshapeGradKernel, int64_t,
-                                ops::ReshapeGradKernel);
+REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel<float>,
+                                double, ops::ReshapeKernel<double>, int,
+                                ops::ReshapeKernel<int>, int64_t,
+                                ops::ReshapeKernel<int64_t>);
+REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float,
+                                ops::ReshapeGradKernel<float>, double,
+                                ops::ReshapeGradKernel<double>, int,
+                                ops::ReshapeGradKernel<int>, int64_t,
+                                ops::ReshapeGradKernel<int64_t>);
+REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel<float>,
+                                double, ops::ReshapeKernel<double>, int,
+                                ops::ReshapeKernel<int>, int64_t,
+                                ops::ReshapeKernel<int64_t>);
+REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float,
+                                ops::ReshapeGradKernel<float>, double,
+                                ops::ReshapeGradKernel<double>, int,
+                                ops::ReshapeGradKernel<int>, int64_t,
+                                ops::ReshapeGradKernel<int64_t>);
 #endif
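Because the functors are now class templates, each registration names a concrete instantiation per dtype instead of repeating one untyped functor. A toy sketch of that per-type dispatch shape (plain C++; the registry and names are hypothetical, not Paddle's macro machinery):

    #include <functional>
    #include <iostream>
    #include <map>
    #include <string>

    // Stand-in for a dtype-templated kernel functor.
    template <typename T>
    struct ToyReshapeKernel {
      void operator()() const {
        std::cout << "run kernel for " << sizeof(T) << "-byte elements\n";
      }
    };

    int main() {
      // One explicit instantiation per supported dtype, mirroring the
      // <float>/<double>/<int>/<int64_t> lists in the macros above.
      std::map<std::string, std::function<void()>> registry{
          {"float", ToyReshapeKernel<float>{}},
          {"double", ToyReshapeKernel<double>{}},
          {"int", ToyReshapeKernel<int>{}},
          {"int64_t", ToyReshapeKernel<int64_t>{}},
      };
      registry.at("float")();  // the framework would key this on tensor dtype
      return 0;
    }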
paddle/fluid/operators/sequence_concat_op.cc
@@ -90,11 +90,12 @@ REGISTER_OPERATOR(sequence_concat, paddle::framework::OperatorWithKernel,
                   paddle::framework::DefaultGradOpDescMaker<false>);
 template <typename T>
 using Kernel = op::SeqConcatKernel<paddle::platform::CPUDeviceContext, T>;
-REGISTER_OP_CPU_KERNEL(sequence_concat, Kernel<float>, Kernel<double>);
+REGISTER_OP_CPU_KERNEL(sequence_concat, Kernel<float>, Kernel<double>,
+                       Kernel<int64_t>);
 REGISTER_OPERATOR(sequence_concat_grad, paddle::framework::OperatorWithKernel,
                   op::SeqConcatGradShapeInferer);
 template <typename T>
 using GradKernel =
     op::SeqConcatGradKernel<paddle::platform::CPUDeviceContext, T>;
-REGISTER_OP_CPU_KERNEL(sequence_concat_grad, GradKernel<float>,
-                       GradKernel<double>);
+REGISTER_OP_CPU_KERNEL(sequence_concat_grad, GradKernel<float>,
+                       GradKernel<double>, GradKernel<int64_t>);
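The sequence_concat change is independent of the reshape work: the CPU kernel and its gradient now also instantiate for int64_t, so the op accepts int64 LoD tensors (presumably token-id sequences) rather than only float and double. The diff touches registration only; evidently the templated SeqConcatKernel already compiles for the extra type.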