Commit e2e82bde authored by minqiyang

Accelerate Reshape op

Parent e1904ac2
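This change makes ReshapeKernel and ReshapeGradKernel templates over the element
type so the kernels can compare raw data pointers, and skips the TensorCopySync
whenever the output already shares the input's buffer, turning an in-place
reshape into a metadata-only Resize. It also defers ReshapeOp::InferShape in
Reshape2Op until after XShape has been set, and registers int64_t CPU kernels
for sequence_concat and sequence_concat_grad.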
@@ -195,6 +195,7 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
   }
 };
 
+template <typename T>
 class ReshapeKernel {
  public:
   void operator()(const framework::ExecutionContext &ctx) const {
@@ -227,12 +228,15 @@ class ReshapeKernel {
                        "sequence_reshape op.");
     }
 
-    out->mutable_data(ctx.GetPlace(), in->type());
-    framework::TensorCopySync(*in, ctx.GetPlace(), out);
+    if (in->data<T>() !=
+        reinterpret_cast<T *>(out->mutable_data(ctx.GetPlace(), in->type()))) {
+      framework::TensorCopySync(*in, ctx.GetPlace(), out);
+    }
     out->Resize(out_dims);
   }
 };
 
+template <typename T>
 class ReshapeGradKernel {
  public:
   void operator()(const framework::ExecutionContext &ctx) const {
@@ -240,8 +244,9 @@ class ReshapeGradKernel {
     auto *d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
     auto in_dims = d_x->dims();
 
-    d_x->mutable_data(ctx.GetPlace(), d_out->type());
-    framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
+    if (d_out->data<T>() != d_x->mutable_data(ctx.GetPlace(), d_out->type())) {
+      framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
+    }
     d_x->Resize(in_dims);
   }
 };
@@ -259,7 +264,6 @@ class Reshape2Op : public ReshapeOp {
       : ReshapeOp(type, inputs, outputs, attrs) {}
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    ReshapeOp::InferShape(ctx);
     PADDLE_ENFORCE(ctx->HasOutput("XShape"),
                    "Output(XShape) of ReshapeOp should not be null.");
     const auto &x_dims = ctx->GetInputDim("X");
@@ -270,6 +274,8 @@ class Reshape2Op : public ReshapeOp {
     }
     ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims));
     ctx->ShareLoD("X", /*->*/ "XShape");
+
+    ReshapeOp::InferShape(ctx);
   }
 };
 
@@ -335,38 +341,46 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp);
-REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
-                               ops::ReshapeKernel, int, ops::ReshapeKernel,
-                               int64_t, ops::ReshapeKernel);
-REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
-                               double, ops::ReshapeGradKernel, int,
-                               ops::ReshapeGradKernel, int64_t,
-                               ops::ReshapeGradKernel);
+REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel<float>,
+                               double, ops::ReshapeKernel<double>, int,
+                               ops::ReshapeKernel<int>, int64_t,
+                               ops::ReshapeKernel<int64_t>);
+REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float,
+                               ops::ReshapeGradKernel<float>, double,
+                               ops::ReshapeGradKernel<double>, int,
+                               ops::ReshapeGradKernel<int>, int64_t,
+                               ops::ReshapeGradKernel<int64_t>);
 
 REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker,
                   ops::Reshape2GradMaker);
 REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp);
-REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
-                               ops::ReshapeKernel, int, ops::ReshapeKernel,
-                               int64_t, ops::ReshapeKernel);
-REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
-                               double, ops::ReshapeGradKernel, int,
-                               ops::ReshapeGradKernel, int64_t,
-                               ops::ReshapeGradKernel);
+REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel<float>,
+                               double, ops::ReshapeKernel<double>, int,
+                               ops::ReshapeKernel<int>, int64_t,
+                               ops::ReshapeKernel<int64_t>);
+REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float,
+                               ops::ReshapeGradKernel<float>, double,
+                               ops::ReshapeGradKernel<double>, int,
+                               ops::ReshapeGradKernel<int>, int64_t,
+                               ops::ReshapeGradKernel<int64_t>);
 
 #ifdef PADDLE_WITH_CUDA
-REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
-                                ops::ReshapeKernel, int, ops::ReshapeKernel,
-                                int64_t, ops::ReshapeKernel);
-REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
-                                double, ops::ReshapeGradKernel, int,
-                                ops::ReshapeGradKernel, int64_t,
-                                ops::ReshapeGradKernel);
-REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
-                                ops::ReshapeKernel, int, ops::ReshapeKernel,
-                                int64_t, ops::ReshapeKernel);
-REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
-                                double, ops::ReshapeGradKernel, int,
-                                ops::ReshapeGradKernel, int64_t,
-                                ops::ReshapeGradKernel);
+REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel<float>,
+                                double, ops::ReshapeKernel<double>, int,
+                                ops::ReshapeKernel<int>, int64_t,
+                                ops::ReshapeKernel<int64_t>);
+REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float,
+                                ops::ReshapeGradKernel<float>, double,
+                                ops::ReshapeGradKernel<double>, int,
+                                ops::ReshapeGradKernel<int>, int64_t,
+                                ops::ReshapeGradKernel<int64_t>);
+REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel<float>,
+                                double, ops::ReshapeKernel<double>, int,
+                                ops::ReshapeKernel<int>, int64_t,
+                                ops::ReshapeKernel<int64_t>);
+REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float,
+                                ops::ReshapeGradKernel<float>, double,
+                                ops::ReshapeGradKernel<double>, int,
+                                ops::ReshapeGradKernel<int>, int64_t,
+                                ops::ReshapeGradKernel<int64_t>);
 #endif
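The core of the speed-up is the pointer comparison above: mutable_data returns
the existing allocation when it can be reused, so when input and output alias
each other the TensorCopySync is skipped and only Resize, a shape-metadata
update, runs. Below is a minimal standalone sketch of that copy-elision
pattern; the toy Buffer type is hypothetical and merely stands in for
framework::Tensor.

// Copy-elision sketch: copy only when source and destination buffers differ.
#include <cstring>
#include <iostream>
#include <vector>

struct Buffer {
  std::vector<float> storage;
  float *mutable_data() { return storage.data(); }   // may reuse old storage
  const float *data() const { return storage.data(); }
};

// When a reshape's output aliases its input, the memcpy is skipped entirely
// and only the (cheap) shape metadata would need updating.
void AssignForReshape(const Buffer &src, Buffer *dst) {
  if (src.data() != dst->mutable_data()) {
    std::memcpy(dst->mutable_data(), src.data(),
                src.storage.size() * sizeof(float));
  }
}

int main() {
  Buffer a{{1.f, 2.f, 3.f, 4.f}};
  Buffer b{std::vector<float>(4)};
  AssignForReshape(a, &b);            // distinct storage: performs a real copy
  AssignForReshape(a, &a);            // aliased storage: the copy is skipped
  std::cout << b.storage[2] << "\n";  // prints 3
}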
@@ -90,11 +90,12 @@ REGISTER_OPERATOR(sequence_concat, paddle::framework::OperatorWithKernel,
                   paddle::framework::DefaultGradOpDescMaker<false>);
 template <typename T>
 using Kernel = op::SeqConcatKernel<paddle::platform::CPUDeviceContext, T>;
-REGISTER_OP_CPU_KERNEL(sequence_concat, Kernel<float>, Kernel<double>);
+REGISTER_OP_CPU_KERNEL(sequence_concat, Kernel<float>, Kernel<double>,
+                       Kernel<int64_t>);
 REGISTER_OPERATOR(sequence_concat_grad, paddle::framework::OperatorWithKernel,
                   op::SeqConcatGradShapeInferer);
 template <typename T>
 using GradKernel =
     op::SeqConcatGradKernel<paddle::platform::CPUDeviceContext, T>;
 REGISTER_OP_CPU_KERNEL(sequence_concat_grad, GradKernel<float>,
-                       GradKernel<double>);
+                       GradKernel<double>, GradKernel<int64_t>);
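The sequence_concat change is registration-only: adding Kernel<int64_t> and
GradKernel<int64_t> entries lets the operator dispatch on int64 inputs as
well. The sketch below illustrates dtype-keyed kernel dispatch with a toy
registry; it is hypothetical and not Paddle's actual
REGISTER_OP_CPU_KERNEL machinery.

// Toy registry keyed by (op name, element type), to show why adding an
// int64_t entry makes lookups for int64 inputs succeed.
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <typeindex>

using Registry =
    std::map<std::pair<std::string, std::type_index>, std::function<void()>>;

Registry &KernelRegistry() {
  static Registry registry;
  return registry;
}

template <typename T>
void RegisterKernel(const std::string &op, std::function<void()> kernel) {
  KernelRegistry()[{op, std::type_index(typeid(T))}] = std::move(kernel);
}

int main() {
  RegisterKernel<float>("sequence_concat",
                        [] { std::cout << "float kernel\n"; });
  // The new registration from this commit, in toy form:
  RegisterKernel<int64_t>("sequence_concat",
                          [] { std::cout << "int64 kernel\n"; });
  // Without the int64_t registration above, this lookup would throw.
  KernelRegistry().at({"sequence_concat", std::type_index(typeid(int64_t))})();
}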