未验证 提交 2a76b42e 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #5419 from reyoung/feature/shrink_memory_op

Feature/shrink memory op
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
class ArrayOp : public framework::OperatorBase {
public:
ArrayOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
protected:
size_t GetOffset(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const {
auto *i = scope.FindVar(Input("I"));
PADDLE_ENFORCE(i != nullptr, "I must be set");
auto &i_tensor = i->Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(i_tensor.numel(), 1);
size_t offset;
if (platform::is_gpu_place(i_tensor.place())) {
// FIXME: Avoid copy from GPU to CPU
framework::Tensor t;
t.CopyFrom(i_tensor, platform::CPUPlace(), dev_ctx);
dev_ctx.Wait();
offset = static_cast<size_t>(*t.data<int64_t>());
} else {
offset = static_cast<size_t>(*i_tensor.data<int64_t>());
}
return offset;
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/lod_rank_table.h"
#include "paddle/operators/array_operator.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
class ShrinkRNNMemoryOp : public ArrayOp {
public:
ShrinkRNNMemoryOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: ArrayOp(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto *x_var = scope.FindVar(Input("X"));
PADDLE_ENFORCE(x_var != nullptr, "Input X must be set");
auto &x_tensor = x_var->Get<framework::LoDTensor>();
size_t offset = this->GetOffset(scope, dev_ctx);
auto *rank_table_var = scope.FindVar(Input("RankTable"));
PADDLE_ENFORCE(rank_table_var != nullptr, "RankTable must be set");
auto &rank_table = rank_table_var->Get<framework::LoDRankTable>();
auto &rank_items = rank_table.items();
int dst_num_rows =
std::lower_bound(rank_items.begin(), rank_items.end(), offset,
[](const framework::LoDRankTable::TableItem &a,
size_t b) { return a.length > b; }) -
rank_items.begin();
auto *out_var = scope.FindVar(Output("Out"));
PADDLE_ENFORCE(out_var != nullptr, "Output Out must be set");
auto &out_tensor = *out_var->GetMutable<framework::LoDTensor>();
if (dst_num_rows != 0) {
out_tensor.ShareDataWith(x_tensor.Slice(0, dst_num_rows));
}
}
};
class ShrinkRNNMemoryOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
ShrinkRNNMemoryOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "");
AddInput("RankTable", "");
AddInput("I", "");
AddOutput("Out", "");
AddComment("");
}
};
class ShrinkRNNMemoryInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("X"));
PADDLE_ENFORCE(context->HasInput("I"));
PADDLE_ENFORCE(context->HasInput("RankTable"));
context->SetOutputDim("Out", context->GetInputDim("X"));
}
};
class ShrinkRNNMemoryGradOp : public ArrayOp {
public:
ShrinkRNNMemoryGradOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: ArrayOp(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto *dout_var = scope.FindVar(Input(framework::GradVarName("Out")));
auto *dx_var = scope.FindVar(Output(framework::GradVarName("X")));
PADDLE_ENFORCE(dx_var != nullptr, "Input Gradient should not be nullptr");
auto *x_var = scope.FindVar(Input("X"));
PADDLE_ENFORCE(x_var != nullptr);
auto &x_tensor = x_var->Get<framework::LoDTensor>();
auto &dx_tensor = *dx_var->GetMutable<framework::LoDTensor>();
dx_tensor.Resize(x_tensor.dims());
dx_tensor.mutable_data(x_tensor.place(), x_tensor.type());
if (dout_var == nullptr) { // dx_tensor fill zero
math::set_constant(dev_ctx, &dx_tensor, 0.0f);
} else {
auto &dout_tensor = dout_var->Get<framework::LoDTensor>();
auto height = dout_tensor.dims()[0];
dx_tensor.Slice(0, static_cast<int>(height))
.CopyFrom(dout_tensor, dout_tensor.place(), dev_ctx);
if (dx_tensor.dims()[0] < height) {
auto rest_tensor = dx_tensor.Slice(
static_cast<int>(height), static_cast<int>(dout_tensor.dims()[0]));
math::set_constant(dev_ctx, &rest_tensor, 0.0f);
}
}
}
};
class ShrinkRNNMemoryGradInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("X"));
PADDLE_ENFORCE(context->HasOutput(framework::GradVarName("X")));
context->SetOutputDim(framework::GradVarName("X"),
context->GetInputDim("X"));
}
};
class ShrinkRNNGradOpMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDescBind> Apply() const override {
auto *op = new framework::OpDescBind();
op->SetType("shrink_rnn_memory_grad");
op->SetInput("X", Input("X"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDescBind>(op);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(shrink_rnn_memory, ops::ShrinkRNNMemoryOp,
ops::ShrinkRNNMemoryInferShape,
ops::ShrinkRNNMemoryOpProtoMaker, ops::ShrinkRNNGradOpMaker);
REGISTER_OPERATOR(shrink_rnn_memory_grad, ops::ShrinkRNNMemoryGradOp,
ops::ShrinkRNNMemoryGradInferShape);
......@@ -11,48 +11,18 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/array_operator.h"
namespace paddle {
namespace operators {
class ArrayOpBase : public framework::OperatorBase {
public:
ArrayOpBase(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {}
protected:
size_t GetOffset(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const {
auto *i = scope.FindVar(Input("I"));
PADDLE_ENFORCE(i != nullptr, "I must be set");
auto &i_tensor = i->Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(i_tensor.numel(), 1);
size_t offset;
if (platform::is_gpu_place(i_tensor.place())) {
// FIXME: Avoid copy from GPU to CPU
framework::Tensor t;
t.CopyFrom(i_tensor, platform::CPUPlace(), dev_ctx);
dev_ctx.Wait();
offset = static_cast<size_t>(*t.data<int64_t>());
} else {
offset = static_cast<size_t>(*i_tensor.data<int64_t>());
}
return offset;
}
};
class WriteToArrayOp : public ArrayOpBase {
class WriteToArrayOp : public ArrayOp {
public:
WriteToArrayOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: ArrayOpBase(type, inputs, outputs, attrs) {}
: ArrayOp(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
......@@ -122,13 +92,13 @@ class WriteToArrayInferVarType : public framework::VarTypeInference {
}
};
class ReadFromArrayOp : public ArrayOpBase {
class ReadFromArrayOp : public ArrayOp {
public:
ReadFromArrayOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: ArrayOpBase(type, inputs, outputs, attrs) {}
: ArrayOp(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto *x = scope.FindVar(Input("X"));
......
......@@ -891,13 +891,13 @@ def zeros(shape, dtype, main_program=None):
def increment(x, value=1.0, main_program=None):
helper = LayerHelper("increment", **locals())
tmp = helper.create_tmp_variable(dtype=x.data_type)
out = helper.create_tmp_variable(dtype=x.data_type)
helper.append_op(
type='increment',
inputs={'X': [x]},
outputs={'Out': [tmp]},
outputs={'Out': [out]},
attrs={'step': value})
return tmp
return out
def array_write(x, i, array=None, main_program=None):
......@@ -928,3 +928,16 @@ def array_read(array, i, main_program=None):
'I': [i]},
outputs={'Out': [out]})
return out
def shrink_memory(x, i, table, main_program=None):
helper = LayerHelper('shrink_memory', **locals())
out = helper.create_tmp_variable(dtype=x.data_type)
helper.append_op(
type='shrink_rnn_memory',
inputs={'X': [x],
'I': [i],
'RankTable': [table]},
outputs={'Out': [out]},
attrs={})
return out
import unittest
import paddle.v2.framework.core as core
from paddle.v2.framework.executor import Executor
import paddle.v2.framework.layers as layers
from paddle.v2.framework.backward import append_backward_ops
from paddle.v2.framework.framework import g_main_program
import numpy
class TestShrinkRNNMemory(unittest.TestCase):
def test_shrink_rnn_memory(self):
x = layers.data('x', shape=[100], data_type='float32')
x.stop_gradient = False
table = layers.lod_rank_table(x=x)
i = layers.zeros(dtype='int64', shape=[1])
mem1 = layers.shrink_memory(x=x, i=i, table=table)
i = layers.increment(x=i)
i.stop_gradient = True
mem2 = layers.shrink_memory(x=mem1, i=i, table=table)
i = layers.increment(x=i)
i.stop_gradient = True
mem3 = layers.shrink_memory(x=mem2, i=i, table=table)
cpu = core.CPUPlace()
tensor = core.LoDTensor()
tensor.set_lod([[0, 2, 5, 6]])
tensor_np = numpy.random.random(size=(3, 100)).astype('float32')
tensor.set(tensor_np, cpu)
exe = Executor(cpu)
outs = map(numpy.array,
exe.run(feed={'x': tensor}, fetch_list=[mem1, mem2, mem3]))
self.assertTrue(numpy.allclose(tensor_np[0:3], outs[0]))
self.assertTrue(numpy.allclose(tensor_np[0:2], outs[1]))
self.assertTrue(numpy.allclose(tensor_np[0:1], outs[2]))
mem3_mean = layers.mean(x=mem3)
append_backward_ops(loss=mem3_mean)
x_grad = map(numpy.array,
exe.run(feed={'x': tensor},
fetch_list=[
g_main_program.global_block().var('x@GRAD')
]))[0]
self.assertAlmostEqual(1.0, x_grad.sum(), delta=0.1)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册