未验证 提交 484cff6e 编写于 作者: Y Yang yaming 提交者: GitHub

Merge pull request #9204 from pkuyym/fix-9171

Enhance LoDResetOp and add python wrapper
......@@ -22,17 +22,16 @@ class LoDResetOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
// input check
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of LoDResetOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of LoDResetOp should not be null.");
// If target LoD is not set form Input(), then it must be set from Attr().
if (!ctx->HasInput("TargetLoD")) {
if (!ctx->HasInput("Y")) {
auto level0 = ctx->Attrs().Get<std::vector<int>>("target_lod");
PADDLE_ENFORCE(level0.size() > 1,
"Target LoD is not found, should be set to be a valid one "
"through Input() or Attr().");
PADDLE_ENFORCE_GT(level0.size(), 1,
"If Input(Y) not provided, the target lod should be "
"specified by attribute `target_lod`.");
}
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
}
......@@ -50,36 +49,77 @@ class LoDResetOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LoDResetOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(LoDTensor) The input tensor of lod_reset operator.");
AddInput("TargetLoD",
"(Tensor, optional) The target level 0 LoD from Input().")
AddInput("X",
"(Tensor, LoDTensor) Input variable of LoDResetOp which "
"could be a Tensor or LoDTensor, where the data of output "
"variable inherits from.");
AddInput("Y",
"(Tensor, LoDTensor, optional) If provided and Y is LoDTensor, "
"lod of Input(Y) would be considered as the target lod first, "
"otherwise data of Input(Y) would be considered as the "
"target lod.")
.AsDispensable();
AddOutput("Out", "(LoDTensor) The output tensor of lod_reset operator.");
AddOutput("Out",
"(LoDTensor) Output variable of LoDResetOp which should be a "
"LoDTensor.");
AddAttr<std::vector<int>>("target_lod",
"The target level 0 LoD from Attr().")
.SetDefault(std::vector<int>{});
AddComment(R"DOC(LoDReset operator
Reset LoD of Input(X) into a new one specified by Input(TargetLoD) or
Attr(target_lod), or set LoD for Input(X) if it doesn't have one.
Currently the lod_reset operator only supports the reset of level 0 LoD.
At least one of Input(TargetLoD) and Attr(target_lod) must be set,
and if both of them are set, Input(TargetLoD) will be chosen as the
target LoD.
Set LoD of `X` to a new one specified by `Y` or attribute `target_lod`. When `Y`
provided and `Y` is a LoDTensor, `Y.lod` would be considered as target LoD
first, otherwise `Y.data` would be considered as target LoD. If `Y` is not
provided, target LoD should be specified by attribute `target_lod`.
If target LoD is specified by `Y.data` or `target_lod`, only one level LoD
is supported.
Example 1:
Given a 1-level LoDTensor input(X):
X.lod = [[ 0, 2, 5 6 ]]
X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
X.dims = [6, 1]
attr(target_lod): [0, 4, 6]
then we get a 1-level LoDTensor:
Out.lod = [[ 0, 4, 6 ]]
Out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
Out.dims = [6, 1]
Example 2:
An example:
Given a float LoDTensor X with shape (6, 1), its transpose form represents
Given a 1-level LoDTensor input(X):
X.lod = [[ 0, 2, 5 6 ]]
X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
X.dims = [6, 1]
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
input(Y) is a Tensor:
Y.data = [[0, 2, 6]]
Y.dims = [1, 3]
with LoD = [[0, 2, 5, 6]] and the three (transposed) sequences look like
then we get a 1-level LoDTensor:
Out.lod = [[ 0, 2, 6 ]]
Out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
Out.dims = [6, 1]
[1.0, 2.0], [3.0, 4.0, 5.0], [6.0].
Example 3:
If target LoD = [0, 4, 6], the lod_reset operator will reset the LoD and
the sequences that the LoDTensor Output(Out) contains becomes:
Given a 1-level LoDTensor input(X):
X.lod = [[ 0, 2, 5 6 ]]
X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
X.dims = [6, 1]
[1.0, 2.0, 3.0, 4.0], [5.0, 6.0].
input(Y) is a 2-level LoDTensor:
Y.lod = [[0, 2, 4], [0, 2, 5, 6]]
Y.data = [[1.1], [2.1], [3.1], [4.1], [5.1], [6.1]]
Y.dims = [6, 1]
then we get a 2-level LoDTensor:
Out.lod = [[0, 2, 4], [0, 2, 5, 6]]
Out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
Out.dims = [6, 1]
)DOC");
}
......@@ -90,10 +130,16 @@ class LoDResetGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of LoDResetGradOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) shouldn't be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
"Input(Out@Grad) of LoDResetGradOp should not be null.");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ x_grad_name);
}
}
protected:
......@@ -111,9 +157,13 @@ class LoDResetGradOp : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
REGISTER_OP(lod_reset, ops::LoDResetOp, ops::LoDResetOpMaker, lod_reset_grad,
ops::LoDResetGradOp);
REGISTER_OP_CPU_KERNEL(lod_reset,
ops::LoDResetKernel<paddle::platform::CPUPlace, float>,
ops::LoDResetKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
lod_reset, ops::LoDResetKernel<paddle::platform::CPUPlace, float>,
ops::LoDResetKernel<paddle::platform::CPUPlace, double>,
ops::LoDResetKernel<paddle::platform::CPUPlace, int>,
ops::LoDResetKernel<paddle::platform::CPUPlace, int64_t>);
REGISTER_OP_CPU_KERNEL(
lod_reset_grad, ops::LoDResetGradKernel<paddle::platform::CPUPlace, float>,
ops::LoDResetGradKernel<paddle::platform::CPUPlace, double>);
ops::LoDResetGradKernel<paddle::platform::CPUPlace, double>,
ops::LoDResetGradKernel<paddle::platform::CPUPlace, int>,
ops::LoDResetGradKernel<paddle::platform::CPUPlace, int64_t>);
......@@ -18,8 +18,12 @@ namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
lod_reset, ops::LoDResetKernel<paddle::platform::CUDADeviceContext, float>,
ops::LoDResetKernel<paddle::platform::CUDADeviceContext, double>);
ops::LoDResetKernel<paddle::platform::CUDADeviceContext, double>,
ops::LoDResetKernel<paddle::platform::CUDADeviceContext, int>,
ops::LoDResetKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
lod_reset_grad,
ops::LoDResetGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::LoDResetGradKernel<paddle::platform::CUDADeviceContext, double>);
ops::LoDResetGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::LoDResetGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::LoDResetGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
......@@ -26,35 +26,46 @@ class LoDResetKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const {
auto* out = ctx.Output<framework::LoDTensor>("Out");
auto* in = ctx.Input<framework::LoDTensor>("X");
auto* lod_t = ctx.Input<framework::Tensor>("TargetLoD");
auto* lod_t = ctx.Input<framework::LoDTensor>("Y");
out->ShareDataWith(*in);
std::vector<int> level0;
if (lod_t) {
auto* lod = lod_t->data<int>();
if (platform::is_gpu_place(ctx.GetPlace())) {
framework::Tensor lod_cpu;
framework::TensorCopy(*lod_t, platform::CPUPlace(),
ctx.device_context(), &lod_cpu);
lod = lod_cpu.data<int>();
if (lod_t->lod().size() > 0) {
auto y_lod = lod_t->lod();
auto last_level = y_lod[y_lod.size() - 1];
PADDLE_ENFORCE_EQ(last_level.back(), in->dims()[0],
"Last value of `Y`'s last level LoD should be equal "
"to the first dimension of `X`");
out->set_lod(y_lod);
return; // early return, since lod already set
} else {
auto* lod = lod_t->data<int>();
if (platform::is_gpu_place(ctx.GetPlace())) {
framework::Tensor lod_cpu;
framework::TensorCopy(*lod_t, platform::CPUPlace(),
ctx.device_context(), &lod_cpu);
lod = lod_cpu.data<int>();
}
level0 = std::vector<int>(lod, lod + lod_t->numel());
}
level0 = std::vector<int>(lod, lod + lod_t->numel());
} else {
level0 = ctx.Attr<std::vector<int>>("target_lod");
}
PADDLE_ENFORCE(level0.size() > 1UL,
"The size of target LoD should be greater than 1.");
PADDLE_ENFORCE(level0[0] == 0,
"Target LoD should be a vector starting from 0.");
PADDLE_ENFORCE(level0.back() == in->dims()[0],
"Target LoD should be a vector end with the "
"first dimension of Input(X).");
PADDLE_ENFORCE_GT(level0.size(), 1UL,
"Size of target LoD should be greater than 1.");
PADDLE_ENFORCE_EQ(level0[0], 0,
"Target LoD should be a vector starting from 0.");
PADDLE_ENFORCE_EQ(level0.back(), in->dims()[0],
"Target LoD should be a vector end with the "
"first dimension of Input(X).");
for (size_t i = 0; i < level0.size() - 1; ++i) {
PADDLE_ENFORCE(level0[i + 1] > level0[i],
"Target LoD should be an ascending vector.");
}
out->ShareDataWith(*in);
// cast level0 to size_t
std::vector<size_t> ulevel0(level0.size(), 0);
std::transform(level0.begin(), level0.end(), ulevel0.begin(),
......
......@@ -73,6 +73,7 @@ __all__ = [
'smooth_l1',
'one_hot',
'autoincreased_step_counter',
'lod_reset',
]
......@@ -2225,7 +2226,7 @@ def reduce_prod(input, dim=None, keep_dim=False, name=None):
keep_dim (bool|False): Whether to reserve the reduced dimension in the
output Tensor. The result tensor will have one fewer dimension
than the :attr:`input` unless :attr:`keep_dim` is true.
name(str|None): A name for this layer(optional). If set None, the
name(str|None): A name for this layer(optional). If set None, the
layer will be named automatically.
Returns:
......@@ -2241,7 +2242,7 @@ def reduce_prod(input, dim=None, keep_dim=False, name=None):
fluid.layers.reduce_prod(x) # [0.0002268]
fluid.layers.reduce_prod(x, dim=0) # [0.02, 0.06, 0.3, 0.63]
fluid.layers.reduce_prod(x, dim=-1) # [0.027, 0.0084]
fluid.layers.reduce_prod(x, dim=1,
fluid.layers.reduce_prod(x, dim=1,
keep_dim=True) # [[0.027], [0.0084]]
"""
helper = LayerHelper('reduce_prod', **locals())
......@@ -3292,3 +3293,98 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1):
counter.stop_gradient = True
return counter
def lod_reset(x, y=None, target_lod=None):
"""
LoD Reset Operator. Set LoD of **x** to a new one specified by **y** or
**target_lod**. When **y** provided, **y.lod** would be considered as target
LoD first, otherwise **y.data** would be considered as target LoD. If **y**
is not provided, target LoD should be specified by **target_lod**.
If target LoD is specified by **Y.data** or **target_lod**, only one level
LoD is supported.
.. code-block:: text
* Example 1:
Given a 1-level LoDTensor x:
x.lod = [[ 0, 2, 5 6 ]]
x.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
x.dims = [6, 1]
target_lod: [0, 4, 6]
then we get a 1-level LoDTensor:
out.lod = [[ 0, 4, 6 ]]
out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
out.dims = [6, 1]
* Example 2:
Given a 1-level LoDTensor x:
x.lod = [[ 0, 2, 5 6 ]]
x.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
x.dims = [6, 1]
y is a Tensor:
y.data = [[0, 2, 6]]
y.dims = [1, 3]
then we get a 1-level LoDTensor:
out.lod = [[ 0, 2, 6 ]]
out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
out.dims = [6, 1]
* Example 3:
Given a 1-level LoDTensor x:
x.lod = [[ 0, 2, 5 6 ]]
x.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
x.dims = [6, 1]
y is a 2-level LoDTensor:
y.lod = [[0, 2, 4], [0, 2, 5, 6]]
y.data = [[1.1], [2.1], [3.1], [4.1], [5.1], [6.1]]
y.dims = [6, 1]
then we get a 2-level LoDTensor:
out.lod = [[0, 2, 4], [0, 2, 5, 6]]
out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
out.dims = [6, 1]
Args:
x (Variable): Input variable which could be a Tensor or LodTensor.
y (Variable|None): If provided, output's LoD would be derived from y.
target_lod (list|tuple|None): One level LoD which should be considered
as target LoD when y not provided.
Returns:
Variable: Output variable with LoD specified by this operator.
Raises:
ValueError: If y and target_lod are both None.
Examples:
.. code-block:: python
x = layers.data(name='x', shape=[10])
y = layers.data(name='y', shape=[10, 20], lod_level=2)
out = layers.lod_reset(x=x, y=y)
"""
helper = LayerHelper("lod_reset", **locals())
out = helper.create_tmp_variable(dtype=x.dtype)
if y is not None:
helper.append_op(
type="lod_reset", inputs={'X': x,
'Y': y}, outputs={'Out': out})
elif target_lod is not None:
helper.append_op(
type="lod_reset",
inputs={'X': x},
attrs={'target_lod': target_lod},
outputs={'Out': out})
else:
raise ValueError("y and target_lod should not be both None.")
return out
......@@ -327,6 +327,15 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(loss)
print(str(program))
def test_lod_reset(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[10], dtype='float32')
y = layers.data(
name='y', shape=[10, 20], dtype='float32', lod_level=2)
print(layers.lod_reset(x=x, y=y))
print(str(program))
if __name__ == '__main__':
unittest.main()
......@@ -42,7 +42,7 @@ class TestLodResetOpByInput(OpTest):
target_lod_0 = [0, 4, 7, 10]
self.inputs = {
'X': (x, lod),
'TargetLoD': np.array([target_lod_0]).astype('int32')
'Y': np.array([target_lod_0]).astype('int32')
}
self.outputs = {'Out': (x, [target_lod_0])}
......@@ -50,7 +50,7 @@ class TestLodResetOpByInput(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out", no_grad_set=set("TargetLoD"))
self.check_grad(["X"], "Out", no_grad_set=set("Y"))
class TestLodResetOpBoth(OpTest):
......@@ -62,7 +62,7 @@ class TestLodResetOpBoth(OpTest):
target_lod_0_in = [0, 4, 7, 10]
self.inputs = {
'X': (x, lod),
'TargetLoD': np.array(target_lod_0_in).astype('int32')
'Y': np.array(target_lod_0_in).astype('int32')
}
self.attrs = {'target_lod': target_lod_0_attr}
self.outputs = {'Out': (x, [target_lod_0_in])}
......@@ -71,7 +71,24 @@ class TestLodResetOpBoth(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out", no_grad_set=set("TargetLoD"))
self.check_grad(["X"], "Out", no_grad_set=set("Y"))
class TestLodResetOpYIsLoDTensor(OpTest):
def setUp(self):
self.op_type = "lod_reset"
x = np.random.random((10, 20)).astype("float32")
lod = [[0, 3, 5, 10]]
y = np.random.random((10, 10)).astype("float32")
target_lod_0 = [[0, 4, 7, 10]]
self.inputs = {'X': (x, lod), 'Y': (y, target_lod_0)}
self.outputs = {'Out': (x, target_lod_0)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out", no_grad_set=set("Y"))
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册