Commit 6d41bfb7 authored by Yang Yu

Add increment op

Parent 568270f3
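This change rewrites the C++ increment operator as a plain OperatorBase that only accepts single-element tensors, defines its backward pass as another increment with a negated step, adds an in_place option to the Python increment layer (along with a LayerHelper naming fix in fill_constant), simplifies the array read/write test, and removes the old kernel-based increment tests.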
paddle/operators/increment_op.cc:

@@ -12,22 +12,57 @@
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/operators/increment_op.h"
+#include "paddle/framework/op_registry.h"
 
 namespace paddle {
 namespace operators {
 
-class IncrementOp : public framework::OperatorWithKernel {
+class IncrementInferShape : public framework::InferShapeBase {
  public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext *ctx) const override {
+  void operator()(framework::InferShapeContext *ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"),
                    "Input(X) of IncrementOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of IncrementOp should not be null.");
+    PADDLE_ENFORCE_EQ(1, framework::product(ctx->GetInputDim("X")));
     ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
-    ctx->ShareLoD("X", /*->*/ "Out");
   }
 };
+
+struct IncrementFunctor {
+  IncrementFunctor(const framework::LoDTensor &x, framework::LoDTensor *out,
+                   float value)
+      : x_(x), out_(out), value_(value) {}
+
+  template <typename T>
+  void operator()() const {
+    *out_->data<T>() = *x_.data<T>() + static_cast<T>(value_);
+  }
+
+  const framework::LoDTensor &x_;
+  framework::LoDTensor *out_;
+  float value_;
+};
+
+class IncrementOp : public framework::OperatorBase {
+ public:
+  IncrementOp(const std::string &type, const framework::VariableNameMap &inputs,
+              const framework::VariableNameMap &outputs,
+              const framework::AttributeMap &attrs)
+      : OperatorBase(type, inputs, outputs, attrs) {}
+
+  void Run(const framework::Scope &scope,
+           const platform::DeviceContext &dev_ctx) const override {
+    auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>();
+    auto &out =
+        *scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();
+
+    PADDLE_ENFORCE(platform::is_cpu_place(x.place()));
+    out.Resize(x.dims());
+    out.mutable_data(x.place(), x.type());
+    float value = Attr<float>("step");
+    framework::VisitDataType(framework::ToDataType(out.type()),
+                             IncrementFunctor(x, &out, value));
+  }
+};
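The rewritten operator skips the kernel registry entirely: Run reads a single-element LoDTensor on the CPU, and VisitDataType invokes IncrementFunctor with the tensor's element type, so out = x + step is computed as whatever type T the tensor holds. A minimal NumPy model of that contract (a hypothetical reference, not Paddle code):

    import numpy as np

    def increment_reference(x, step=1.0):
        # mirrors the new InferShape: exactly one element is allowed
        assert x.size == 1, "product(dims) must equal 1"
        # mirrors static_cast<T>(value_): the step is cast to x's element type
        return x + np.asarray(step).astype(x.dtype)

    counter = np.array([7], dtype=np.int64)
    print(increment_reference(counter))  # [8]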
@@ -59,10 +94,10 @@ class IncrementGradOpMaker : public framework::SingleGradOpDescMaker {
 
   std::unique_ptr<framework::OpDescBind> Apply() const override {
     auto *grad_op = new framework::OpDescBind();
-    grad_op->SetType("scale");
-    grad_op->SetInput("X", OutputGrad("Out"));
-    grad_op->SetOutput("Out", InputGrad("X"));
-    grad_op->SetAttr("scale", 1.0f);
+    grad_op->SetType("increment");
+    grad_op->SetInput("X", Output("Out"));
+    grad_op->SetOutput("Out", Input("X"));
+    grad_op->SetAttr("step", -boost::get<float>(GetAttr("step")));
     return std::unique_ptr<framework::OpDescBind>(grad_op);
   }
 };
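The backward definition changed in kind, not just in detail: the old maker emitted a scale op that passed the output gradient through unchanged, while the new one emits another increment that maps Out itself back to X with a negated step. That is, it re-derives the previous counter value, x = out + (-step), rather than propagating a conventional gradient. A quick scalar check of that identity:

    step = 1.0
    x = 3.0
    out = x + step             # forward: increment(X) -> Out
    x_again = out + (-step)    # backward op: increment(Out, step=-step) -> X
    assert x_again == x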
@@ -71,11 +106,5 @@
 } // namespace paddle
 
 namespace ops = paddle::operators;
-
-REGISTER_OPERATOR(increment, ops::IncrementOp, ops::IncrementOpMaker,
-                  ops::IncrementGradOpMaker);
-REGISTER_OP_CPU_KERNEL(
-    increment, ops::IncrementKernel<paddle::platform::CPUPlace, float>,
-    ops::IncrementKernel<paddle::platform::CPUPlace, double>,
-    ops::IncrementKernel<paddle::platform::CPUPlace, int>,
-    ops::IncrementKernel<paddle::platform::CPUPlace, int64_t>);
+REGISTER_OPERATOR(increment, ops::IncrementOp, ops::IncrementInferShape,
+                  ops::IncrementOpMaker, ops::IncrementGradOpMaker);
Python layers module:

@@ -800,7 +800,7 @@ def array_to_lod_tensor(x, table, main_program=None):
 
 def fill_constant(shape, dtype, value, main_program=None):
-    helper = LayerHelper("ones", **locals())
+    helper = LayerHelper("fill_constant", **locals())
     out = helper.create_tmp_variable(dtype=dtype)
     helper.append_op(
         type='fill_constant',
@@ -823,9 +823,12 @@ def zeros(shape, dtype, main_program=None):
     return fill_constant(value=0.0, **locals())
 
-def increment(x, value=1.0, main_program=None):
+def increment(x, value=1.0, in_place=False, main_program=None):
     helper = LayerHelper("increment", **locals())
-    tmp = helper.create_tmp_variable(dtype=x.data_type)
+    if in_place:
+        tmp = x
+    else:
+        tmp = helper.create_tmp_variable(dtype=x.data_type)
     helper.append_op(
         type='increment',
         inputs={'X': [x]},
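A short usage sketch of the updated layer, assuming the layers module patched above; with in_place=True the op writes its result back into x's variable instead of allocating a temporary:

    i = layers.zeros(shape=[1], dtype='int64')           # counter starts at 0
    i = layers.increment(x=i, value=1.0, in_place=True)  # reuses i as the output
    j = layers.increment(x=i)                            # default: fresh tmp variable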
Array read/write test (TestArrayReadWrite):

@@ -20,21 +20,19 @@ class TestArrayReadWrite(unittest.TestCase):
             each_x.stop_gradient = False
 
         i = layers.zeros(shape=[1], dtype='int64')
-        i.stop_gradient = False
         arr = layers.array_write(x=x[0], i=i)
         i = layers.increment(x=i)
-        i.stop_gradient = True
         arr = layers.array_write(x=x[1], i=i, array=arr)
         i = layers.increment(x=i)
-        i.stop_gradient = True
         arr = layers.array_write(x=x[2], i=i, array=arr)
 
         i = layers.zeros(shape=[1], dtype='int64')
-        i.stop_gradient = False
         a0 = layers.array_read(array=arr, i=i)
         i = layers.increment(x=i)
+        i.stop_gradient = True  # index should not calculate gradient
         a1 = layers.array_read(array=arr, i=i)
         i = layers.increment(x=i)
+        i.stop_gradient = True
         a2 = layers.array_read(array=arr, i=i)
 
         mean_a0 = layers.mean(x=a0)
test_increment_op.py (file deleted):

-import unittest
-import numpy as np
-from op_test import OpTest
-
-
-class TestIncrementOpPositiveStep(OpTest):
-    """Test increment op with positive step
-    """
-
-    def setUp(self):
-        self.op_type = "increment"
-        self.inputs = {'X': np.random.random((10, 10)).astype("float32")}
-        self.attrs = {'step': 14.8}
-        self.outputs = {'Out': self.inputs['X'] + self.attrs['step']}
-
-    def test_check_output(self):
-        self.check_output()
-
-    def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
-
-
-class TestIncrementOpNegativeStep(OpTest):
-    """Test increment op with negative step
-    """
-
-    def setUp(self):
-        self.op_type = "increment"
-        self.inputs = {'X': np.random.random((10, 10)).astype("float32")}
-        self.attrs = {'step': -3.8}
-        self.outputs = {'Out': self.inputs['X'] + self.attrs['step']}
-
-    def test_check_output(self):
-        self.check_output()
-
-    def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
-
-
-if __name__ == "__main__":
-    unittest.main()
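These tests push a (10, 10) input through OpTest's kernel and gradient checks, which the rewritten operator can no longer satisfy: the new InferShape rejects anything but a single element, and the CPU kernels they relied on are unregistered above. A hypothetical smoke test for the new contract could be as small as:

    import numpy as np

    x = np.random.random((1,)).astype("float32")  # single element, as the op now requires
    step = 14.8
    out = x + step
    assert out.shape == x.shape and np.isclose(out[0], x[0] + step)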