未验证 提交 e579dc1c 编写于 作者: W whs 提交者: GitHub

Merge pull request #6708 from wanghaoshuang/rename_seq

Rename seq to sequence in sequence_expand_op
...@@ -12,14 +12,14 @@ ...@@ -12,14 +12,14 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/seq_expand_op.h" #include "paddle/operators/sequence_expand_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using framework::Tensor; using framework::Tensor;
class SeqExpandOp : public framework::OperatorWithKernel { class SequenceExpandOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -35,25 +35,25 @@ class SeqExpandOp : public framework::OperatorWithKernel { ...@@ -35,25 +35,25 @@ class SeqExpandOp : public framework::OperatorWithKernel {
} }
}; };
class SeqExpandOpMaker : public framework::OpProtoAndCheckerMaker { class SequenceExpandOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SeqExpandOpMaker(framework::OpProto* proto, SequenceExpandOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker) framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(Tensor or LoDTensor) The input(X) of this operator can be a " "(Tensor or LoDTensor) The input(X) of this operator can be a "
"LoDTensor or a base Tensor."); "LoDTensor or a base Tensor.");
AddInput("Y", AddInput("Y",
"(LoDTensor)The reference input(Y) of seq_expand op." "(LoDTensor)The reference input(Y) of sequence_expand op."
"It must be a LoDTensor with k-level(k>0)." "It must be a LoDTensor with k-level(k>0)."
"The input(X) will be expanded according to LOD of input(Y)." "The input(X) will be expanded according to LOD of input(Y)."
"The element numbers of last level in input(Y) " "The element numbers of last level in input(Y) "
"must be equal to dims[0] of input(X)."); "must be equal to dims[0] of input(X).");
AddOutput("Out", AddOutput("Out",
"(LodTensor)The output of seq_expand op." "(LodTensor)The output of sequence_expand op."
"The lod of output will be as same as input(Y)'s lod."); "The lod of output will be as same as input(Y)'s lod.");
AddComment(R"DOC( AddComment(R"DOC(
Seq Expand Operator. Sequence Expand Operator.
This operator expands input(X) according to LOD of input(Y). This operator expands input(X) according to LOD of input(Y).
Following are cases to better explain how this works: Following are cases to better explain how this works:
...@@ -124,7 +124,7 @@ then we get 2-level LoDTensor ...@@ -124,7 +124,7 @@ then we get 2-level LoDTensor
} }
}; };
class SeqExpandOpGrad : public framework::OperatorWithKernel { class SequenceExpandOpGrad : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -146,11 +146,11 @@ class SeqExpandOpGrad : public framework::OperatorWithKernel { ...@@ -146,11 +146,11 @@ class SeqExpandOpGrad : public framework::OperatorWithKernel {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(seq_expand, ops::SeqExpandOp, ops::SeqExpandOpMaker, REGISTER_OP(sequence_expand, ops::SequenceExpandOp, ops::SequenceExpandOpMaker,
seq_expand_grad, ops::SeqExpandOpGrad); sequence_expand_grad, ops::SequenceExpandOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
seq_expand, sequence_expand,
ops::SeqExpandKernel<paddle::platform::CPUDeviceContext, float>); ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
seq_expand_grad, sequence_expand_grad,
ops::SeqExpandGradKernel<paddle::platform::CPUDeviceContext, float>); ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, float>);
...@@ -13,12 +13,12 @@ ...@@ -13,12 +13,12 @@
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/operators/seq_expand_op.h" #include "paddle/operators/sequence_expand_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
seq_expand, sequence_expand,
ops::SeqExpandKernel<paddle::platform::CUDADeviceContext, float>); ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
seq_expand_grad, sequence_expand_grad,
ops::SeqExpandGradKernel<paddle::platform::CUDADeviceContext, float>); ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext, float>);
...@@ -24,7 +24,7 @@ namespace operators { ...@@ -24,7 +24,7 @@ namespace operators {
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class SeqExpandKernel : public framework::OpKernel<T> { class SequenceExpandKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<LoDTensor>("X"); auto* x = context.Input<LoDTensor>("X");
...@@ -71,7 +71,7 @@ class SeqExpandKernel : public framework::OpKernel<T> { ...@@ -71,7 +71,7 @@ class SeqExpandKernel : public framework::OpKernel<T> {
* *
* */ * */
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class SeqExpandGradKernel : public framework::OpKernel<T> { class SequenceExpandGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out")); auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
......
...@@ -3,7 +3,7 @@ import numpy as np ...@@ -3,7 +3,7 @@ import numpy as np
from op_test import OpTest from op_test import OpTest
class TestSeqExpand(OpTest): class TestSequenceExpand(OpTest):
def set_data(self): def set_data(self):
x_data = np.random.uniform(0.1, 1, [3, 1]).astype('float32') x_data = np.random.uniform(0.1, 1, [3, 1]).astype('float32')
y_data = np.random.uniform(0.1, 1, [8, 1]).astype('float32') y_data = np.random.uniform(0.1, 1, [8, 1]).astype('float32')
...@@ -21,7 +21,7 @@ class TestSeqExpand(OpTest): ...@@ -21,7 +21,7 @@ class TestSeqExpand(OpTest):
self.outputs = {'Out': out} self.outputs = {'Out': out}
def setUp(self): def setUp(self):
self.op_type = 'seq_expand' self.op_type = 'sequence_expand'
self.set_data() self.set_data()
self.compute() self.compute()
...@@ -32,7 +32,7 @@ class TestSeqExpand(OpTest): ...@@ -32,7 +32,7 @@ class TestSeqExpand(OpTest):
self.check_grad(["X"], "Out") self.check_grad(["X"], "Out")
class TestSeqExpandCase1(TestSeqExpand): class TestSequenceExpandCase1(TestSequenceExpand):
def set_data(self): def set_data(self):
x_data = np.random.uniform(0.1, 1, [5, 1]).astype('float32') x_data = np.random.uniform(0.1, 1, [5, 1]).astype('float32')
x_lod = [[0, 2, 5]] x_lod = [[0, 2, 5]]
...@@ -41,7 +41,7 @@ class TestSeqExpandCase1(TestSeqExpand): ...@@ -41,7 +41,7 @@ class TestSeqExpandCase1(TestSeqExpand):
self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)} self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}
class TestSeqExpandCase2(TestSeqExpand): class TestSequenceExpandCase2(TestSequenceExpand):
def set_data(self): def set_data(self):
x_data = np.random.uniform(0.1, 1, [1, 2, 2]).astype('float32') x_data = np.random.uniform(0.1, 1, [1, 2, 2]).astype('float32')
x_lod = [[0, 1]] x_lod = [[0, 1]]
...@@ -50,7 +50,7 @@ class TestSeqExpandCase2(TestSeqExpand): ...@@ -50,7 +50,7 @@ class TestSeqExpandCase2(TestSeqExpand):
self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)} self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}
class TestSeqExpandCase3(TestSeqExpand): class TestSequenceExpandCase3(TestSequenceExpand):
def set_data(self): def set_data(self):
x_data = np.random.uniform(0.1, 1, [4, 1]).astype('float32') x_data = np.random.uniform(0.1, 1, [4, 1]).astype('float32')
x_lod = [[0, 1, 2, 3, 4]] x_lod = [[0, 1, 2, 3, 4]]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册