提交 08cb472a 编写于 作者: Y yangyaming

Simplify the implementation.

上级 fc581bc5
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/sequence_reshape_op.h"
#include "paddle/framework/ddim.h"
namespace paddle {
namespace operators {
......@@ -26,9 +27,11 @@ class SequenceReshapeOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SequenceReshapeOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto x_numel = product(x_dims);
PADDLE_ENFORCE_EQ(x_dims.size(), 2U, "Rank of Input(X) should be 2.");
int dimension = ctx->Attrs().Get<int>("new_dim");
ctx->SetOutputDim("Out", {x_dims[0], static_cast<int64_t>(dimension)});
int new_dim = ctx->Attrs().Get<int>("new_dim");
ctx->SetOutputDim("Out",
{x_numel / new_dim, static_cast<int64_t>(new_dim)});
}
};
......@@ -54,16 +57,16 @@ example will help to illustrate the function of this operator:
x is a LoDTensor:
x.lod = [[0, 2, 6]]
x.data = [[0.1, 0.2], [0.3, 0.4],
[0.5, 0.6], [0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]
x.data = [[1, 2], [3, 4],
[5, 6], [7, 8], [9, 10], [11, 12]]
x.dims = [6, 2]
set new_dim = 4
then out is a LoDTensor:
out.lod = [[0, 1, 3]]
out.data = [[0.1, 0.2, 0.3, 0.4],
[0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]]
out.lod = [[0, 1, 3]]
out.data = [[1, 2, 3, 4],
[5, 6, 7, 8], [9, 10, 11, 12]]
out.dims = [3, 4]
Currently, only 1-level LoDTensor is supported and please make sure (original
......@@ -82,8 +85,6 @@ class SequenceReshapeGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(
ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) of SequenceReshapeGradOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Out"),
"Input(Out) of SequenceReshapeGradOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequenceReshapeGradOp should not be null.");
......@@ -101,7 +102,6 @@ class SequenceReshapeGradOpMaker : public framework::SingleGradOpDescMaker {
auto* op_desc_ptr = new framework::OpDesc();
op_desc_ptr->SetType("sequence_reshape_grad");
op_desc_ptr->SetInput("X", Input("X"));
op_desc_ptr->SetInput("Out", Output("Out"));
op_desc_ptr->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op_desc_ptr->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op_desc_ptr->SetAttrMap(Attrs());
......@@ -118,7 +118,13 @@ REGISTER_OPERATOR(sequence_reshape, ops::SequenceReshapeOp,
REGISTER_OPERATOR(sequence_reshape_grad, ops::SequenceReshapeGradOp);
REGISTER_OP_CPU_KERNEL(
sequence_reshape,
ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, float>);
ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, float>,
ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, double>,
ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, int>,
ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
sequence_reshape_grad,
ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext, float>);
ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext, int>);
......@@ -17,7 +17,14 @@ limitations under the License. */
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
sequence_reshape,
ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, float>);
ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, float>,
ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, double>,
ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, int>,
ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
sequence_reshape_grad,
ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext, float>);
ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext,
int64_t>,
ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext, int>);
......@@ -28,8 +28,6 @@ class SequenceReshapeKernel : public framework::OpKernel<T> {
auto* out = context.Output<LoDTensor>("Out");
int out_width = context.Attr<int>("new_dim");
const T* p_in_data = in->data<T>();
auto in_dims = in->dims();
int64_t in_width = in_dims[1];
auto& in_lod = in->lod();
......@@ -43,53 +41,29 @@ class SequenceReshapeKernel : public framework::OpKernel<T> {
auto in_lod_l0 = in_lod[0];
int seq_num = in_lod_l0.size() - 1;
auto& out_lod = *out->mutable_lod();
out_lod.resize(1);
out_lod[0].clear();
out_lod[0].push_back(0);
for (int i = 0; i < seq_num; ++i) {
size_t seq_len = in_lod_l0[i + 1] - in_lod_l0[i];
size_t offset = 0;
offset = (seq_len * in_width) / out_width;
PADDLE_ENFORCE_EQ(offset * out_width, seq_len * in_width,
"Please make sure (sequence_length * dimension) can be "
"divided by new_dim with no remainder for each "
"sequence. The %dth sequence is invalid.",
i + 1);
PADDLE_ENFORCE_GT(offset, 0,
"Illegal operation, length of the %dth sequence become "
"to 0 after reshaped.",
i + 1);
out_lod[0].push_back(out_lod[0].back() + offset);
if (in_width == out_width) {
out->set_lod(in->lod());
} else {
auto& out_lod = *out->mutable_lod();
out_lod.resize(1);
out_lod[0].clear();
out_lod[0].push_back(0);
for (int i = 0; i < seq_num; ++i) {
size_t seq_len = in_lod_l0[i + 1] - in_lod_l0[i];
size_t offset = 0;
offset = (seq_len * in_width) / out_width;
PADDLE_ENFORCE_EQ(offset * out_width, seq_len * in_width,
"Please make sure (sequence_length * dimension) can "
"be divided by new_dim with no remainder for each "
"sequence. The %dth sequence is invalid.",
i + 1);
out_lod[0].push_back(out_lod[0].back() + offset);
}
}
out->mutable_data<T>(context.GetPlace());
out->Resize({static_cast<int64_t>(out_lod[0].back()), out_width});
T* p_out_data = out->mutable_data<T>(context.GetPlace());
math::set_constant(context.device_context(), out, 0.0f);
for (int i = 0; i < seq_num; ++i) {
size_t in_offset = in_lod_l0[i] * in_width;
size_t out_offset = out_lod[0][i] * out_width;
size_t in_count = (in_lod_l0[i + 1] - in_lod_l0[i]) * in_width;
size_t out_count = (out_lod[0][i + 1] - out_lod[0][i]) * out_width;
size_t bytes = sizeof(T) * std::min(in_count, out_count);
if (platform::is_cpu_place(context.GetPlace())) {
memory::Copy(boost::get<platform::CPUPlace>(context.GetPlace()),
p_out_data + out_offset,
boost::get<platform::CPUPlace>(context.GetPlace()),
p_in_data + in_offset, bytes);
} else {
#ifdef PADDLE_WITH_CUDA
auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
p_out_data + out_offset,
boost::get<platform::CUDAPlace>(context.GetPlace()),
p_in_data + in_offset, bytes, dev_ctx.stream());
#endif
}
}
framework::Copy(*in, context.GetPlace(), out);
out->Resize({static_cast<int64_t>(out->lod()[0].back()), out_width});
}
};
......@@ -98,45 +72,14 @@ class SequenceReshapeGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x_tensor_ptr = context.Input<LoDTensor>("X");
auto* out_tensor_ptr = context.Input<LoDTensor>("Out");
auto* out_grad_tensor_ptr =
auto* outg_tensor_ptr =
context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* x_grad_tensor_ptr =
auto* xg_tensor_ptr =
context.Output<LoDTensor>(framework::GradVarName("X"));
T* p_x_grad_data = x_grad_tensor_ptr->mutable_data<T>(context.GetPlace());
const T* p_out_grad_data = out_grad_tensor_ptr->data<T>();
auto& x_lod = x_tensor_ptr->lod();
int seq_num = x_lod[0].size() - 1;
int x_width = x_tensor_ptr->dims()[1];
auto& out_lod = out_tensor_ptr->lod();
int out_width = out_tensor_ptr->dims()[1];
math::set_constant(context.device_context(), x_grad_tensor_ptr, 0.0f);
for (int i = 0; i < seq_num; ++i) {
size_t src_offset = out_lod[0][i] * out_width;
size_t dst_offset = x_lod[0][i] * x_width;
size_t src_count = (out_lod[0][i + 1] - out_lod[0][i]) * out_width;
size_t dst_count = (x_lod[0][i + 1] - x_lod[0][i]) * x_width;
size_t bytes = sizeof(T) * std::min(src_count, dst_count);
if (platform::is_cpu_place(context.GetPlace())) {
memory::Copy(boost::get<platform::CPUPlace>(context.GetPlace()),
p_x_grad_data + dst_offset,
boost::get<platform::CPUPlace>(context.GetPlace()),
p_out_grad_data + src_offset, bytes);
} else {
#ifdef PADDLE_WITH_CUDA
auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
p_x_grad_data + dst_offset,
boost::get<platform::CUDAPlace>(context.GetPlace()),
p_out_grad_data + src_offset, bytes, dev_ctx.stream());
#endif
}
}
xg_tensor_ptr->mutable_data<T>(context.GetPlace());
framework::Copy(*outg_tensor_ptr, context.GetPlace(), xg_tensor_ptr);
xg_tensor_ptr->Resize(x_tensor_ptr->dims());
}
};
......
......@@ -40,14 +40,7 @@ class TestSequenceReshape(OpTest):
assert int(offset) * dimension == seq_len * x_width
out_lod[0].append(out_lod[0][-1] + int(offset))
out = np.zeros(shape=(out_lod[0][-1], dimension)).astype('float32')
for i in xrange(len(x_lod[0]) - 1):
x_offset = x_lod[0][i] * x_width
out_offset = out_lod[0][i] * dimension
out_count = (out_lod[0][i + 1] - out_lod[0][i]) * dimension
x_count = (x_lod[0][i + 1] - x_lod[0][i]) * x_width
count = min(out_count, x_count)
out.ravel()[out_offset:out_offset + count] = x.ravel()[
x_offset:x_offset + count]
out.ravel()[:] = x.ravel()[:]
return out, out_lod
def test_check_output(self):
......@@ -72,5 +65,20 @@ class TestSequenceReshape_reduce(TestSequenceReshape):
self.outputs = {'Out': (out, out_lod)}
class TestSequenceReshape_same(TestSequenceReshape):
def setUp(self):
self.op_type = 'sequence_reshape'
dimension = 12
x_lod = [[0, 4, 6, 8, 12]]
x = np.random.uniform(0.1, 1, [12, 12]).astype('float32')
self.inputs = {'X': (x, x_lod)}
self.attrs = {'new_dim': dimension}
out, out_lod = self.compute_output(x, x_lod, dimension)
self.outputs = {'Out': (out, out_lod)}
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册