提交 e2e82bde 编写于 作者: M minqiyang

Accelerate Reshape op

上级 e1904ac2
......@@ -164,7 +164,7 @@ dimension value will be copied from Input(X) at runtime. Note that the index of
[2, 3, 4], Attr(shape) = [2, 3, 2, 0] is an invalid input.
3. Input(Shape) has a higher priority than Attr(shape) if it is provided, while
Attr(shape) still should be set correctly to gurantee shape inference in
Attr(shape) still should be set correctly to gurantee shape inference in
compile-time.
)DOC");
......@@ -195,6 +195,7 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
}
};
template <typename T>
class ReshapeKernel {
public:
void operator()(const framework::ExecutionContext &ctx) const {
......@@ -227,12 +228,15 @@ class ReshapeKernel {
"sequence_reshape op.");
}
out->mutable_data(ctx.GetPlace(), in->type());
framework::TensorCopySync(*in, ctx.GetPlace(), out);
if (in->data<T>() !=
reinterpret_cast<T *>(out->mutable_data(ctx.GetPlace(), in->type()))) {
framework::TensorCopySync(*in, ctx.GetPlace(), out);
}
out->Resize(out_dims);
}
};
template <typename T>
class ReshapeGradKernel {
public:
void operator()(const framework::ExecutionContext &ctx) const {
......@@ -240,8 +244,9 @@ class ReshapeGradKernel {
auto *d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto in_dims = d_x->dims();
d_x->mutable_data(ctx.GetPlace(), d_out->type());
framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
if (d_out->data<T>() != d_x->mutable_data(ctx.GetPlace(), d_out->type())) {
framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
}
d_x->Resize(in_dims);
}
};
......@@ -259,7 +264,6 @@ class Reshape2Op : public ReshapeOp {
: ReshapeOp(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
ReshapeOp::InferShape(ctx);
PADDLE_ENFORCE(ctx->HasOutput("XShape"),
"Output(XShape) of ReshapeOp should not be null.");
const auto &x_dims = ctx->GetInputDim("X");
......@@ -270,6 +274,8 @@ class Reshape2Op : public ReshapeOp {
}
ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims));
ctx->ShareLoD("X", /*->*/ "XShape");
ReshapeOp::InferShape(ctx);
}
};
......@@ -335,38 +341,46 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel,
int64_t, ops::ReshapeKernel);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
double, ops::ReshapeGradKernel, int,
ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel<float>,
double, ops::ReshapeKernel<double>, int,
ops::ReshapeKernel<int>, int64_t,
ops::ReshapeKernel<int64_t>);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float,
ops::ReshapeGradKernel<float>, double,
ops::ReshapeGradKernel<double>, int,
ops::ReshapeGradKernel<int>, int64_t,
ops::ReshapeGradKernel<int64_t>);
REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker,
ops::Reshape2GradMaker);
REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel,
int64_t, ops::ReshapeKernel);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
double, ops::ReshapeGradKernel, int,
ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel<float>,
double, ops::ReshapeKernel<double>, int,
ops::ReshapeKernel<int>, int64_t,
ops::ReshapeKernel<int64_t>);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float,
ops::ReshapeGradKernel<float>, double,
ops::ReshapeGradKernel<double>, int,
ops::ReshapeGradKernel<int>, int64_t,
ops::ReshapeGradKernel<int64_t>);
#ifdef PADDLE_WITH_CUDA
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel,
int64_t, ops::ReshapeKernel);
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
double, ops::ReshapeGradKernel, int,
ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel);
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel,
int64_t, ops::ReshapeKernel);
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
double, ops::ReshapeGradKernel, int,
ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel);
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel<float>,
double, ops::ReshapeKernel<double>, int,
ops::ReshapeKernel<int>, int64_t,
ops::ReshapeKernel<int64_t>);
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float,
ops::ReshapeGradKernel<float>, double,
ops::ReshapeGradKernel<double>, int,
ops::ReshapeGradKernel<int>, int64_t,
ops::ReshapeGradKernel<int64_t>);
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel<float>,
double, ops::ReshapeKernel<double>, int,
ops::ReshapeKernel<int>, int64_t,
ops::ReshapeKernel<int64_t>);
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float,
ops::ReshapeGradKernel<float>, double,
ops::ReshapeGradKernel<double>, int,
ops::ReshapeGradKernel<int>, int64_t,
ops::ReshapeGradKernel<int64_t>);
#endif
......@@ -90,11 +90,12 @@ REGISTER_OPERATOR(sequence_concat, paddle::framework::OperatorWithKernel,
paddle::framework::DefaultGradOpDescMaker<false>);
template <typename T>
using Kernel = op::SeqConcatKernel<paddle::platform::CPUDeviceContext, T>;
REGISTER_OP_CPU_KERNEL(sequence_concat, Kernel<float>, Kernel<double>);
REGISTER_OP_CPU_KERNEL(sequence_concat, Kernel<float>, Kernel<double>,
Kernel<int64_t>);
REGISTER_OPERATOR(sequence_concat_grad, paddle::framework::OperatorWithKernel,
op::SeqConcatGradShapeInferer);
template <typename T>
using GradKernel =
op::SeqConcatGradKernel<paddle::platform::CPUDeviceContext, T>;
REGISTER_OP_CPU_KERNEL(sequence_concat_grad, GradKernel<float>,
GradKernel<double>);
GradKernel<double>, GradKernel<int64_t>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册