diff --git a/paddle/operators/reshape_op.cc b/paddle/operators/reshape_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..b7061153d2bf13982f14f233e87a87daeeebf5fd --- /dev/null +++ b/paddle/operators/reshape_op.cc @@ -0,0 +1,107 @@ + +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/reshape_op.h" + +namespace paddle { +namespace operators { + +class ReshapeOp : public framework::OperatorWithKernel { + public: + ReshapeOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + // input check + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) shouldn't be null"); + auto shape = ctx.Attr>("shape"); + PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty."); + for (auto dim : shape) { + PADDLE_ENFORCE(dim > 0, "Each dimension of shape must be positive."); + } + // capacity check + int64_t capacity = + std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + auto *in = ctx.Input("X"); + int64_t in_size = framework::product(in->dims()); + PADDLE_ENFORCE_EQ(capacity, in_size, + "The size of Input(X) mismatches with Attr(shape)."); + // resize output + std::vector shape_int64(shape.size(), 0); + std::transform(shape.begin(), shape.end(), shape_int64.begin(), + [](int a) { return static_cast(a); }); + auto out_dims = framework::make_ddim(shape_int64); + ctx.Output("Out")->Resize(out_dims); + } +}; + +class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ReshapeOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The input tensor of reshape operator."); + AddOutput("Out", "The output tensor of reshape operator."); + AddAttr>("shape", "Target shape of reshape operator."); + AddComment(R"DOC(Reshape operator + +Reshape Input(X) into the shape specified by Attr(shape). + +An example: +Given a 2-D tensor X with 2 rows and 2 columns + + [[1, 2], [3, 4]] + +with target shape = [1, 4], the reshape operator will transform +the tensor X into a 1-D tensor: + + [1, 2, 3, 4] + +)DOC"); + } +}; + +class ReshapeGradOp : public framework::OperatorWithKernel { + public: + ReshapeGradOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) shouldn't be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Input(Out@GRAD) shouldn't be null."); + auto dims = ctx.Input("X")->dims(); + auto *d_in = ctx.Output(framework::GradVarName("X")); + d_in->Resize(dims); + } +}; + +} // namespace operators +} // namespace paddle +namespace ops = paddle::operators; + +REGISTER_OP(reshape, ops::ReshapeOp, ops::ReshapeOpMaker, reshape_grad, + ops::ReshapeGradOp); +REGISTER_OP_CPU_KERNEL(reshape, + ops::ReshapeKernel); +REGISTER_OP_CPU_KERNEL( + reshape_grad, ops::ReshapeGradKernel); diff --git a/paddle/operators/reshape_op.cu b/paddle/operators/reshape_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..23dbe089d3b37aabedf9ef166f7bbfbf67da7e0a --- /dev/null +++ b/paddle/operators/reshape_op.cu @@ -0,0 +1,22 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/reshape_op.h" + +REGISTER_OP_GPU_KERNEL( + reshape, + paddle::operators::ReshapeKernel); +REGISTER_OP_GPU_KERNEL( + reshape_grad, + paddle::operators::ReshapeGradKernel); diff --git a/paddle/operators/reshape_op.h b/paddle/operators/reshape_op.h new file mode 100644 index 0000000000000000000000000000000000000000..26708e72dc8f80d2cff1c1ee5e8763b959320205 --- /dev/null +++ b/paddle/operators/reshape_op.h @@ -0,0 +1,56 @@ + +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class ReshapeKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto* out = ctx.Output("Out"); + auto* in = ctx.Input("X"); + out->mutable_data(ctx.GetPlace()); + + auto shape = ctx.Attr>("shape"); + std::vector shape_int64(shape.size(), 0); + std::transform(shape.begin(), shape.end(), shape_int64.begin(), + [](int a) { return static_cast(a); }); + auto out_dims = framework::make_ddim(shape_int64); + out->CopyFrom(*in, ctx.GetPlace()); + out->Resize(out_dims); + } +}; + +template +class ReshapeGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto* d_out = ctx.Input(framework::GradVarName("Out")); + auto* d_x = ctx.Output(framework::GradVarName("X")); + d_x->mutable_data(ctx.GetPlace()); + + auto in_dims = d_x->dims(); + d_x->CopyFrom(*d_out, ctx.GetPlace()); + d_x->Resize(in_dims); + } +}; +} +} diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 3958b53c22c383e5e2298bfdc4e8490d4148118f..16a2368aae5fff7445654161db4fd6a97d5bebfc 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -54,6 +54,7 @@ USE_CPU_ONLY_OP(concat); USE_OP(top_k); USE_OP(squared_l2_distance); USE_OP(sum); +USE_OP(reshape); namespace paddle { namespace framework { diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index 3de9e69e34d3d2be53b597d489323466a0fe4033..6b22c0008210b492d00dee42e967ca14d0948b20 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -35,3 +35,4 @@ py_test(test_sum_op SRCS test_sum_op.py) py_test(mnist SRCS mnist.py) py_test(test_concat_op SRCS test_concat_op.py) py_test(test_squared_l2_distance_op SRCS test_squared_l2_distance_op.py) +py_test(test_reshape_op SRCS test_reshape_op.py) diff --git a/python/paddle/v2/framework/tests/test_reshape_op.py b/python/paddle/v2/framework/tests/test_reshape_op.py new file mode 100644 index 0000000000000000000000000000000000000000..16bb6bb2af67f7d32a2fafc1cb37412084ec0829 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_reshape_op.py @@ -0,0 +1,21 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestReshapeOp(OpTest): + def setUp(self): + self.op_type = "reshape" + self.inputs = {'X': np.random.random((10, 20)).astype("float32")} + self.attrs = {'shape': [10 * 20]} + self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +if __name__ == '__main__': + unittest.main()