提交 352fa41a 编写于 作者: Y yangyaming

Finish adapting forward.

上级 5a159f34
......@@ -17,7 +17,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
using framework::Tensor;
using framework::LoDTensor;
class SequenceExpandOp : public framework::OperatorWithKernel {
public:
......@@ -25,15 +25,67 @@ class SequenceExpandOp : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"));
PADDLE_ENFORCE(ctx->HasOutput("Out"));
PADDLE_ENFORCE(ctx->HasInput("Y"));
framework::DDim out_dim;
auto y_dim = ctx->GetInputDim("Y");
out_dim = ctx->GetInputDim("X");
out_dim[0] = y_dim[0];
ctx->ShareLoD("Y", "Out");
ctx->SetOutputDim("Out", out_dim);
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequenceExpandOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"),
"Input(Y) of SequenceExpandOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SequenceExpandOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 2U,
"Dimension number of Input(X) should be 2.");
int ref_level = ctx->Attrs().Get<int>("ref_level");
if (ctx->IsRuntime()) {
framework::Variable* x_var =
boost::get<framework::Variable*>(ctx->GetInputVarPtrs("X")[0]);
framework::Variable* y_var =
boost::get<framework::Variable*>(ctx->GetInputVarPtrs("Y")[0]);
auto& x_lod = x_var->Get<LoDTensor>().lod();
auto& y_lod = y_var->Get<LoDTensor>().lod();
PADDLE_ENFORCE_LE(x_lod.size(), 1,
"Number of lod level of Input(X) should not be "
"greater than 1.");
PADDLE_ENFORCE(x_lod.size() == y_lod.size() || x_lod.size() == 0,
"Number of lod level of Input(X) either equal to 0 "
"or equal to that of Input(Y).");
int64_t out_first_dim = 0;
if (y_lod[ref_level].size() < 1) {
out_first_dim = x_dims[0];
} else {
if (x_lod.size() == 1) { // X is LoDTensor
for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
int x_seq_len = x_lod[0][i] - x_lod[0][i - 1];
out_first_dim +=
(y_lod[ref_level][i] - y_lod[ref_level][i - 1]) * x_seq_len;
}
} else { // X is normal Tensor
for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
out_first_dim += y_lod[ref_level][i] - y_lod[ref_level][i - 1];
}
}
}
ctx->SetOutputDim("Out", {out_first_dim, x_dims[1]});
} else {
framework::VarDesc* in_reader =
boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("Y")[0]);
int lod_level_num = in_reader->GetLoDLevels().size();
PADDLE_ENFORCE_GE(ref_level, 0,
"Level of referred lod should be greater or "
"equal to 0.");
PADDLE_ENFORCE_LT(ref_level, lod_level_num,
"Level of referred lod should be smaller than "
"level number of Input(Y).");
ctx->SetOutputDim("Out", {-1, x_dims[1]});
}
}
};
......@@ -42,17 +94,15 @@ class SequenceExpandOpMaker : public framework::OpProtoAndCheckerMaker {
SequenceExpandOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor or LoDTensor) The input(X) of this operator can be a "
"LoDTensor or a base Tensor.");
"(LoDTensor, default LoDTensor<float>) A 2-D LoDTensor whose lod "
"level is at most 1.");
AddInput("Y",
"(LoDTensor)The reference input(Y) of sequence_expand op."
"It must be a LoDTensor with k-level(k>0)."
"The input(X) will be expanded according to LOD of input(Y)."
"The element numbers of last level in input(Y) "
"must be equal to dims[0] of input(X).");
"(LoDTensor, default LoDTensor<float>) Referred LoDTensor whose "
"lod (specified level) is referred by Input(X).");
AddOutput("Out",
"(LodTensor)The output of sequence_expand op."
"The lod of output will be as same as input(Y)'s lod.");
"(LodTensor, default LoDTensor<float>) Output LoDTensor which is "
"generated from Input(X) by referring lod of Input(Y).");
AddAttr<int>("ref_level", "Specify lod level of Input(Y).");
AddComment(R"DOC(
Sequence Expand Operator.
......@@ -129,12 +179,14 @@ class SequenceExpandOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"));
PADDLE_ENFORCE(ctx->HasInput("Out"));
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"The input(Out@GRAD) should not be null");
"Input(Out@GRAD) should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
......@@ -149,7 +201,13 @@ REGISTER_OP(sequence_expand, ops::SequenceExpandOp, ops::SequenceExpandOpMaker,
sequence_expand_grad, ops::SequenceExpandOpGrad);
REGISTER_OP_CPU_KERNEL(
sequence_expand,
ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, float>);
ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, float>,
ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, double>,
ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, int>,
ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
sequence_expand_grad,
ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, float>);
ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
......@@ -18,7 +18,14 @@ limitations under the License. */
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
sequence_expand,
ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, float>);
ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, float>,
ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, double>,
ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, int>,
ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
sequence_expand_grad,
ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext, float>);
ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
......@@ -28,33 +28,57 @@ class SequenceExpandKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
const T* x_data = x->data<T>();
auto x_dims = x->dims();
auto* y = context.Input<LoDTensor>("Y");
PADDLE_ENFORCE(!y->lod().empty(), "y should have lod");
PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dims[0]),
y->lod().back().size() - 1,
"The size of last lod level in Input(Y)"
"must be equal to dims[0] of Input(X).");
out->set_lod(y->lod());
auto* place =
context.template device_context<DeviceContext>().eigen_device();
size_t element_len = framework::product(x_dims) / x_dims[0];
T* out_data = out->mutable_data<T>(context.GetPlace());
auto out_starts = out->lod().back();
auto* out = context.Output<LoDTensor>("Out");
int ref_level = context.Attr<int>("ref_level");
for (size_t i = 0; i < out_starts.size() - 1; i++) {
int scale = out_starts[i + 1] - out_starts[i];
Eigen::TensorMap<
Eigen::Tensor<const T, 2, Eigen::RowMajor, Eigen::DenseIndex>>
x_t(x_data, 1, element_len);
Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, Eigen::DenseIndex>>
out_t(out_data, scale, element_len);
Eigen::array<int, 2> cast({{scale, 1}});
out_t.device(*place) = x_t.broadcast(cast);
x_data += element_len;
out_data += element_len * scale;
auto& x_lod = x->lod();
auto& y_lod = y->lod();
PADDLE_ENFORCE_GE(ref_level, 0,
"Value of attribute `ref_level` should be greater or "
"equal to 0.");
PADDLE_ENFORCE_LT(ref_level, y_lod.size(),
"Value of attribute `ref_level` should be smaller than "
"level number of Y's lod.");
if (y_lod[ref_level].size() < 1) {
framework::TensorCopy(*x, context.GetPlace(), out);
return;
}
if (x_lod.size() == 0) {
int out_start = 0;
for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1];
auto x_sub_tensor = x->Slice(i - 1, i);
for (size_t j = 0; j < repeat_num; ++j) {
auto out_sub_tensor = out->Slice(out_start, out_start + 1);
framework::TensorCopy(x_sub_tensor, context.GetPlace(),
&out_sub_tensor);
out_start++;
}
}
} else {
auto& out_lod = *out->mutable_lod();
out_lod.resize(1);
out_lod[0].resize(1);
out_lod[0][0] = 0;
int out_idx = 0;
for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1];
int x_seq_len = x_lod[0][i] - x_lod[0][i - 1];
auto x_sub_tensor = x->Slice(x_lod[0][i], x_lod[0][i - 1]);
for (size_t j = 0; j < repeat_num; ++j) {
auto out_sub_tensor =
out->Slice(out_lod[0][out_idx], out_lod[0][out_idx] + x_seq_len);
framework::TensorCopy(x_sub_tensor, context.GetPlace(),
&out_sub_tensor);
out_lod[0].push_back(out_lod[0][out_idx] + x_seq_len);
out_idx++;
}
}
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册