提交 352fa41a 编写于 作者: Y yangyaming

Finish adapting forward.

上级 5a159f34
...@@ -17,7 +17,7 @@ limitations under the License. */ ...@@ -17,7 +17,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using framework::Tensor; using framework::LoDTensor;
class SequenceExpandOp : public framework::OperatorWithKernel { class SequenceExpandOp : public framework::OperatorWithKernel {
public: public:
...@@ -25,15 +25,67 @@ class SequenceExpandOp : public framework::OperatorWithKernel { ...@@ -25,15 +25,67 @@ class SequenceExpandOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X")); PADDLE_ENFORCE(ctx->HasInput("X"),
PADDLE_ENFORCE(ctx->HasOutput("Out")); "Input(X) of SequenceExpandOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y")); PADDLE_ENFORCE(ctx->HasInput("Y"),
framework::DDim out_dim; "Input(Y) of SequenceExpandOp should not be null.");
auto y_dim = ctx->GetInputDim("Y"); PADDLE_ENFORCE(ctx->HasOutput("Out"),
out_dim = ctx->GetInputDim("X"); "Output(Out) of SequenceExpandOp should not be null.");
out_dim[0] = y_dim[0];
ctx->ShareLoD("Y", "Out"); auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", out_dim); PADDLE_ENFORCE_EQ(x_dims.size(), 2U,
"Dimension number of Input(X) should be 2.");
int ref_level = ctx->Attrs().Get<int>("ref_level");
if (ctx->IsRuntime()) {
framework::Variable* x_var =
boost::get<framework::Variable*>(ctx->GetInputVarPtrs("X")[0]);
framework::Variable* y_var =
boost::get<framework::Variable*>(ctx->GetInputVarPtrs("Y")[0]);
auto& x_lod = x_var->Get<LoDTensor>().lod();
auto& y_lod = y_var->Get<LoDTensor>().lod();
PADDLE_ENFORCE_LE(x_lod.size(), 1,
"Number of lod level of Input(X) should not be "
"greater than 1.");
PADDLE_ENFORCE(x_lod.size() == y_lod.size() || x_lod.size() == 0,
"Number of lod level of Input(X) either equal to 0 "
"or equal to that of Input(Y).");
int64_t out_first_dim = 0;
if (y_lod[ref_level].size() < 1) {
out_first_dim = x_dims[0];
} else {
if (x_lod.size() == 1) { // X is LoDTensor
for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
int x_seq_len = x_lod[0][i] - x_lod[0][i - 1];
out_first_dim +=
(y_lod[ref_level][i] - y_lod[ref_level][i - 1]) * x_seq_len;
}
} else { // X is normal Tensor
for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
out_first_dim += y_lod[ref_level][i] - y_lod[ref_level][i - 1];
}
}
}
ctx->SetOutputDim("Out", {out_first_dim, x_dims[1]});
} else {
framework::VarDesc* in_reader =
boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("Y")[0]);
int lod_level_num = in_reader->GetLoDLevels().size();
PADDLE_ENFORCE_GE(ref_level, 0,
"Level of referred lod should be greater or "
"equal to 0.");
PADDLE_ENFORCE_LT(ref_level, lod_level_num,
"Level of referred lod should be smaller than "
"level number of Input(Y).");
ctx->SetOutputDim("Out", {-1, x_dims[1]});
}
} }
}; };
...@@ -42,17 +94,15 @@ class SequenceExpandOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -42,17 +94,15 @@ class SequenceExpandOpMaker : public framework::OpProtoAndCheckerMaker {
SequenceExpandOpMaker(OpProto* proto, OpAttrChecker* op_checker) SequenceExpandOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(Tensor or LoDTensor) The input(X) of this operator can be a " "(LoDTensor, default LoDTensor<float>) A 2-D LoDTensor whose lod "
"LoDTensor or a base Tensor."); "level is at most 1.");
AddInput("Y", AddInput("Y",
"(LoDTensor)The reference input(Y) of sequence_expand op." "(LoDTensor, default LoDTensor<float>) Referred LoDTensor whose "
"It must be a LoDTensor with k-level(k>0)." "lod (specified level) is referred by Input(X).");
"The input(X) will be expanded according to LOD of input(Y)."
"The element numbers of last level in input(Y) "
"must be equal to dims[0] of input(X).");
AddOutput("Out", AddOutput("Out",
"(LodTensor)The output of sequence_expand op." "(LodTensor, default LoDTensor<float>) Output LoDTensor which is "
"The lod of output will be as same as input(Y)'s lod."); "generated from Input(X) by referring lod of Input(Y).");
AddAttr<int>("ref_level", "Specify lod level of Input(Y).");
AddComment(R"DOC( AddComment(R"DOC(
Sequence Expand Operator. Sequence Expand Operator.
...@@ -129,12 +179,14 @@ class SequenceExpandOpGrad : public framework::OperatorWithKernel { ...@@ -129,12 +179,14 @@ class SequenceExpandOpGrad : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X")); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Out")); PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"The input(Out@GRAD) should not be null"); "Input(Out@GRAD) should not be null.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto x_grad_name = framework::GradVarName("X"); auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) { if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims); ctx->SetOutputDim(x_grad_name, x_dims);
} }
...@@ -149,7 +201,13 @@ REGISTER_OP(sequence_expand, ops::SequenceExpandOp, ops::SequenceExpandOpMaker, ...@@ -149,7 +201,13 @@ REGISTER_OP(sequence_expand, ops::SequenceExpandOp, ops::SequenceExpandOpMaker,
sequence_expand_grad, ops::SequenceExpandOpGrad); sequence_expand_grad, ops::SequenceExpandOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
sequence_expand, sequence_expand,
ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, float>); ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, float>,
ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, double>,
ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, int>,
ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
sequence_expand_grad, sequence_expand_grad,
ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, float>); ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
...@@ -18,7 +18,14 @@ limitations under the License. */ ...@@ -18,7 +18,14 @@ limitations under the License. */
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
sequence_expand, sequence_expand,
ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, float>); ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, float>,
ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, double>,
ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, int>,
ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
sequence_expand_grad, sequence_expand_grad,
ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext, float>); ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
...@@ -28,33 +28,57 @@ class SequenceExpandKernel : public framework::OpKernel<T> { ...@@ -28,33 +28,57 @@ class SequenceExpandKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<LoDTensor>("X"); auto* x = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
const T* x_data = x->data<T>();
auto x_dims = x->dims();
auto* y = context.Input<LoDTensor>("Y"); auto* y = context.Input<LoDTensor>("Y");
PADDLE_ENFORCE(!y->lod().empty(), "y should have lod"); auto* out = context.Output<LoDTensor>("Out");
PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dims[0]), int ref_level = context.Attr<int>("ref_level");
y->lod().back().size() - 1,
"The size of last lod level in Input(Y)"
"must be equal to dims[0] of Input(X).");
out->set_lod(y->lod());
auto* place =
context.template device_context<DeviceContext>().eigen_device();
size_t element_len = framework::product(x_dims) / x_dims[0];
T* out_data = out->mutable_data<T>(context.GetPlace());
auto out_starts = out->lod().back();
for (size_t i = 0; i < out_starts.size() - 1; i++) { auto& x_lod = x->lod();
int scale = out_starts[i + 1] - out_starts[i]; auto& y_lod = y->lod();
Eigen::TensorMap<
Eigen::Tensor<const T, 2, Eigen::RowMajor, Eigen::DenseIndex>> PADDLE_ENFORCE_GE(ref_level, 0,
x_t(x_data, 1, element_len); "Value of attribute `ref_level` should be greater or "
Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, Eigen::DenseIndex>> "equal to 0.");
out_t(out_data, scale, element_len);
Eigen::array<int, 2> cast({{scale, 1}}); PADDLE_ENFORCE_LT(ref_level, y_lod.size(),
out_t.device(*place) = x_t.broadcast(cast); "Value of attribute `ref_level` should be smaller than "
x_data += element_len; "level number of Y's lod.");
out_data += element_len * scale;
if (y_lod[ref_level].size() < 1) {
framework::TensorCopy(*x, context.GetPlace(), out);
return;
}
if (x_lod.size() == 0) {
int out_start = 0;
for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1];
auto x_sub_tensor = x->Slice(i - 1, i);
for (size_t j = 0; j < repeat_num; ++j) {
auto out_sub_tensor = out->Slice(out_start, out_start + 1);
framework::TensorCopy(x_sub_tensor, context.GetPlace(),
&out_sub_tensor);
out_start++;
}
}
} else {
auto& out_lod = *out->mutable_lod();
out_lod.resize(1);
out_lod[0].resize(1);
out_lod[0][0] = 0;
int out_idx = 0;
for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1];
int x_seq_len = x_lod[0][i] - x_lod[0][i - 1];
auto x_sub_tensor = x->Slice(x_lod[0][i], x_lod[0][i - 1]);
for (size_t j = 0; j < repeat_num; ++j) {
auto out_sub_tensor =
out->Slice(out_lod[0][out_idx], out_lod[0][out_idx] + x_seq_len);
framework::TensorCopy(x_sub_tensor, context.GetPlace(),
&out_sub_tensor);
out_lod[0].push_back(out_lod[0][out_idx] + x_seq_len);
out_idx++;
}
}
} }
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册