未验证 提交 2dd91dd5 编写于 作者: Y Yang Yu

Shrink State Operator

Used for shrink memories state in DyRNN. The height of state could
be shrinked after running a step block.
上级 25711eda
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
class ArrayOp : public framework::OperatorBase {
public:
ArrayOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
protected:
size_t GetOffset(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const {
auto *i = scope.FindVar(Input("I"));
PADDLE_ENFORCE(i != nullptr, "I must be set");
auto &i_tensor = i->Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(i_tensor.numel(), 1);
size_t offset;
if (platform::is_gpu_place(i_tensor.place())) {
// FIXME: Avoid copy from GPU to CPU
framework::Tensor t;
t.CopyFrom(i_tensor, platform::CPUPlace(), dev_ctx);
dev_ctx.Wait();
offset = static_cast<size_t>(*t.data<int64_t>());
} else {
offset = static_cast<size_t>(*i_tensor.data<int64_t>());
}
return offset;
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/lod_rank_table.h"
#include "paddle/operators/array_operator.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
class ShrinkStateOp : public ArrayOp {
public:
ShrinkStateOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: ArrayOp(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto *x_var = scope.FindVar(Input("X"));
PADDLE_ENFORCE(x_var != nullptr, "Input X must be set");
auto &x_tensor = x_var->Get<framework::LoDTensor>();
size_t offset = this->GetOffset(scope, dev_ctx);
auto *rank_table_var = scope.FindVar(Input("RankTable"));
PADDLE_ENFORCE(rank_table_var != nullptr, "RankTable must be set");
auto &rank_table = rank_table_var->Get<framework::LoDRankTable>();
int dst_num_rows = 0;
{
auto &rank_items = rank_table.items();
for (auto &rank_item : rank_items) {
if (offset < rank_item.length) {
++dst_num_rows;
} else {
break;
}
}
}
auto *out_var = scope.FindVar(Output("Out"));
PADDLE_ENFORCE(out_var != nullptr, "Output Out must be set");
auto &out_tensor = *out_var->GetMutable<framework::LoDTensor>();
if (dst_num_rows != 0) {
out_tensor.ShareDataWith(x_tensor.Slice(0, dst_num_rows));
}
}
};
class ShrinkStateOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
ShrinkStateOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "");
AddInput("RankTable", "");
AddInput("I", "");
AddOutput("Out", "");
AddComment("");
}
};
class ShrinkStateOpInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("X"));
PADDLE_ENFORCE(context->HasInput("I"));
PADDLE_ENFORCE(context->HasInput("RankTable"));
context->SetOutputDim("Out", context->GetInputDim("X"));
}
};
class ShrinkStateGradOp : public ArrayOp {
public:
ShrinkStateGradOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: ArrayOp(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto *dout_var = scope.FindVar(Input(framework::GradVarName("Out")));
auto dx_name = Output(framework::GradVarName("X"));
auto *dx_var = scope.FindVar(dx_name);
PADDLE_ENFORCE(dx_var != nullptr, "Input Gradient should not be nullptr");
auto *x_var = scope.FindVar(Input("X"));
PADDLE_ENFORCE(x_var != nullptr);
auto &x_tensor = x_var->Get<framework::LoDTensor>();
auto &dx_tensor = *dx_var->GetMutable<framework::LoDTensor>();
dx_tensor.Resize(x_tensor.dims());
dx_tensor.mutable_data(x_tensor.place(), x_tensor.type());
if (dout_var == nullptr) { // dx_tensor fill zero
math::set_constant(dev_ctx, &dx_tensor, 0.0f);
} else {
auto &dout_tensor = dout_var->Get<framework::LoDTensor>();
auto height = dout_tensor.dims()[0];
dx_tensor.Slice(0, static_cast<int>(height))
.CopyFrom(dout_tensor, dout_tensor.place(), dev_ctx);
if (height < dout_tensor.dims()[0]) {
auto rest_tensor = dx_tensor.Slice(
static_cast<int>(height), static_cast<int>(dout_tensor.dims()[0]));
math::set_constant(dev_ctx, &rest_tensor, 0.0f);
}
}
}
};
class ShrikStateGradInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("X"));
PADDLE_ENFORCE(context->HasOutput(framework::GradVarName("X")));
context->SetOutputDim(framework::GradVarName("X"),
context->GetInputDim("X"));
}
};
class ShrinkStateGradOpMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDescBind> Apply() const override {
auto *op = new framework::OpDescBind();
op->SetType("shrink_state_grad");
op->SetInput("X", Input("X"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDescBind>(op);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(shrink_state, ops::ShrinkStateOp,
ops::ShrinkStateOpInferShape, ops::ShrinkStateOpProtoMaker,
ops::ShrinkStateGradOpMaker);
REGISTER_OPERATOR(shrink_state_grad, ops::ShrinkStateGradOp,
ops::ShrikStateGradInferShape);
......@@ -11,48 +11,18 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/array_operator.h"
namespace paddle {
namespace operators {
class ArrayOpBase : public framework::OperatorBase {
public:
ArrayOpBase(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {}
protected:
size_t GetOffset(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const {
auto *i = scope.FindVar(Input("I"));
PADDLE_ENFORCE(i != nullptr, "I must be set");
auto &i_tensor = i->Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(i_tensor.numel(), 1);
size_t offset;
if (platform::is_gpu_place(i_tensor.place())) {
// FIXME: Avoid copy from GPU to CPU
framework::Tensor t;
t.CopyFrom(i_tensor, platform::CPUPlace(), dev_ctx);
dev_ctx.Wait();
offset = static_cast<size_t>(*t.data<int64_t>());
} else {
offset = static_cast<size_t>(*i_tensor.data<int64_t>());
}
return offset;
}
};
class WriteToArrayOp : public ArrayOpBase {
class WriteToArrayOp : public ArrayOp {
public:
WriteToArrayOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: ArrayOpBase(type, inputs, outputs, attrs) {}
: ArrayOp(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
......@@ -115,6 +85,7 @@ class WriteToArrayInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDescBind &op_desc,
framework::BlockDescBind *block) const override {
VLOG(10) << "I am here?";
for (auto &out_var : op_desc.OutputArgumentNames()) {
VLOG(10) << "Set Variable " << out_var << " as LOD_TENSOR_ARRAY";
block->Var(out_var)->SetType(framework::VarDesc::LOD_TENSOR_ARRAY);
......@@ -122,13 +93,13 @@ class WriteToArrayInferVarType : public framework::VarTypeInference {
}
};
class ReadFromArrayOp : public ArrayOpBase {
class ReadFromArrayOp : public ArrayOp {
public:
ReadFromArrayOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: ArrayOpBase(type, inputs, outputs, attrs) {}
: ArrayOp(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto *x = scope.FindVar(Input("X"));
......
......@@ -801,13 +801,12 @@ def zeros(shape, dtype, main_program=None):
def increment(x, value=1.0, main_program=None):
helper = LayerHelper("increment", **locals())
tmp = helper.create_tmp_variable(dtype=x.data_type)
helper.append_op(
type='increment',
inputs={'X': [x]},
outputs={'Out': [tmp]},
outputs={'Out': [x]},
attrs={'step': value})
return tmp
return x
def array_write(x, i, array=None, main_program=None):
......@@ -838,3 +837,16 @@ def array_read(array, i, main_program=None):
'I': [i]},
outputs={'Out': [out]})
return out
def shrink_memory(x, i, table, main_program=None):
helper = LayerHelper('shrink_memory', **locals())
out = helper.create_tmp_variable(dtype=x.data_type)
helper.append_op(
type='shrink_state',
inputs={'X': [x],
'I': [i],
'RankTable': [table]},
outputs={'Out': [out]},
attrs={})
return out
import unittest
import paddle.v2.framework.core as core
from paddle.v2.framework.executor import Executor
import paddle.v2.framework.layers as layers
from paddle.v2.framework.backward import append_backward_ops
from paddle.v2.framework.framework import g_main_program
import numpy
class TestShrinkState(unittest.TestCase):
def test_shrink_state(self):
x = layers.data('x', shape=[100], data_type='float32')
x.stop_gradient = False
table = layers.lod_rank_table(x=x)
i = layers.zeros(dtype='int64', shape=[1])
mem1 = layers.shrink_memory(x=x, i=i, table=table)
i = layers.increment(x=i)
i.stop_gradient = True
mem2 = layers.shrink_memory(x=mem1, i=i, table=table)
i = layers.increment(x=i)
i.stop_gradient = True
mem3 = layers.shrink_memory(x=mem2, i=i, table=table)
cpu = core.CPUPlace()
tensor = core.LoDTensor()
tensor.set_lod([[0, 2, 5, 6]])
tensor_np = numpy.random.random(size=(3, 100)).astype('float32')
tensor.set(tensor_np, cpu)
exe = Executor(cpu)
outs = map(numpy.array,
exe.run(feed={'x': tensor}, fetch_list=[mem1, mem2, mem3]))
self.assertTrue(numpy.allclose(tensor_np[0:3], outs[0]))
self.assertTrue(numpy.allclose(tensor_np[0:2], outs[1]))
self.assertTrue(numpy.allclose(tensor_np[0:1], outs[2]))
mem3_mean = layers.mean(x=mem3)
append_backward_ops(loss=mem3_mean)
x_grad = map(numpy.array,
exe.run(feed={'x': tensor},
fetch_list=[
g_main_program.global_block().var('x@GRAD')
]))[0]
self.assertAlmostEqual(1.0, x_grad.sum(), delta=0.1)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册