diff --git a/doc/design/images/graph_construction_example.dot b/doc/design/images/graph_construction_example.dot index 8d1b673abf6b78c851676fa379dc850c4818f0e5..e115f9844bae6ad24f638c8ed4749cea8aff06a9 100644 --- a/doc/design/images/graph_construction_example.dot +++ b/doc/design/images/graph_construction_example.dot @@ -33,7 +33,6 @@ digraph ImageClassificationGraph { cost -> MSE_Grad [color=red]; d_cost -> MSE_Grad [color=red]; - x -> MSE_Grad [color=red]; l -> MSE_Grad [color=red]; y -> MSE_Grad -> d_y [color=red]; diff --git a/doc/design/images/graph_construction_example_all.png b/doc/design/images/graph_construction_example_all.png index 181187503472d15779b87284105841168b3945c4..261611a5721f9aa97874f7e6d897fe48cf667db2 100644 Binary files a/doc/design/images/graph_construction_example_all.png and b/doc/design/images/graph_construction_example_all.png differ diff --git a/doc/design/images/graph_construction_example_forward_backward.png b/doc/design/images/graph_construction_example_forward_backward.png index 3049a9315fd616464dec54e33064cb75598ca536..4c69687f4a6a181138f3df72ce5e8aa48487b5be 100644 Binary files a/doc/design/images/graph_construction_example_forward_backward.png and b/doc/design/images/graph_construction_example_forward_backward.png differ diff --git a/doc/design/images/graph_construction_example_forward_only.png b/doc/design/images/graph_construction_example_forward_only.png index 25d19088cbf0b5f68cf734f2ff21eba8af4a2860..e668c16e0cac73acb4e5dc2b1827557ae77126b4 100644 Binary files a/doc/design/images/graph_construction_example_forward_only.png and b/doc/design/images/graph_construction_example_forward_only.png differ diff --git a/doc/design/var_desc.md b/doc/design/var_desc.md index bfbbdd0578ebc69ea4b49ade9b041573a9e9ad55..0b2958c1b10ef6a6ce51aa75f61e15a7f2d94b3f 100644 --- a/doc/design/var_desc.md +++ b/doc/design/var_desc.md @@ -16,16 +16,23 @@ The computation graph is constructed by Data Node and Operation Node. The concep ## Definition of VarDesc -A VarDesc should have a name and value, in PaddlePaddle, the value will always be a tensor. Since we use LoDTensor most of the time. We add a LoDTesnorDesc to represent it. +A VarDesc should have a name and a value. There are two kinds of variable types at compile time: `LoDTensor` and `SelectedRows`. ```proto message VarDesc { required string name = 1; - optional LoDTensorDesc lod_tensor = 2; + enum VarType { + LOD_TENSOR = 0; + SELECTED_ROWS = 1; + } + required VarType type = 2; + optional LoDTensorDesc lod_desc = 3; + optional TensorDesc selected_rows_desc = 4; + optional bool persistable = 5 [ default = false ]; } ``` -## Definition of LodTensorDesc +## Definition of TensorDesc ```proto enum DataType { @@ -38,87 +45,25 @@ enum DataType { FP64 = 6; } -message LoDTensorDesc { +message TensorDesc { required DataType data_type = 1; - repeated int32 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480] - optional int32 lod_level = 3 [default=0]; + repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480] } ``` -## Definition of Variable in Python - -In Python API, layer will take Variable as Input, and return Variable as Output. There should be a class `Variable` in python to help create and manage Variable. - -```python -image = Variable(dims=[-1, 640, 480]) -# fc1 and fc2 are both Variable -fc1 = layer.fc(input=image, output_size=10) -fc2 = layer.fc(input=fc1, output_size=20) -``` -### what should class `Variable` Have -1. `name`.a name of string type is used to mark the value of the Variable. -1. 
`initializer`. Since our Tensor does not have value. we will always use some Operator to fullfill it when run. So we should have a initialize method to help add the init operator. -1. `operator`. Variable should record which operator produce itself. The reaon is: - - we use pd.eval(targets=[var1, var2]) to run the related ops to get the value of var1 and var2. var.op is used to trace the dependency of the current variable. - -In PaddlePaddle, we use Block to describe Computation Graph, so in the code we will use Block but not Graph. - -```python -import VarDesc -import LoDTensorDesc -import framework - -def AddInitialOperator(variable, initializer): - # add an initialize Operator to block to init this Variable - -class Variable(object): - def __init__(self, name, dims, type, initializer): - self._block = get_default_block() - self._name = name - self.op = None - - tensor_desc = LoDTensorDesc(data_type=type, dims=dims) - _var_desc = VarDesc(name=name, lod_tensor=tensor_desc) - self._var = framework.CreateVar(_var_desc) - self._block.add_var(self) +A TensorDesc describes `SelectedRows` and `LoDTensor`. For details of `SelectedRows`, please refer to [`SelectedRows`](./selected_rows.md). - # add initial op according to initializer - if initializer is not None: - AddInitialOperator(self, initializer) - - def dims(self): - return self._var.dims() - - def data_type(self): - return self._var.data_type() +## Definition of LoDTensorDesc - def to_proto(self): - pass +```proto +message LoDTensorDesc { + required TensorDesc tensor = 1; + optional int32 lod_level = 2; +} ``` -Then we can use this Variable to create a fc layer in Python. +A LoDTensorDesc contains a tensor and a lod_level. -```python -import paddle as pd - -def flatten_size(X, num_flatten_dims): - prod = 1 # of last num_flatten_dims - for i in xrange(num_flatten_dims): - prod = prod * X.dims[-i-1] - return prod - -def layer.fc(X, output_size, num_flatten_dims): - W = Variable(pd.random_uniform(), type=FP32, dims=[flatten_size(X, num_flatten_dims), output_size]) - b = Variable(pd.random_uniform(), type=FP32, dims=[output_size]) - out = Variable(type=FP32) - y = operator.fc(X, W, b, output=out) # fc will put fc op input into out - pd.InferShape(y) - return out - -x = Variable(dims=[-1, 640, 480]) -y = layer.fc(x, output_size=100) -z = layer.fc(y, output_size=200) +## Definition of Variable in Python -paddle.eval(targets=[z], ...) -print(z) -``` +For Variable in Python, please refer to [`Python API`](./python_api.md). 
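To make the `VarDesc`/`LoDTensorDesc`/`TensorDesc` schema in the var_desc.md change concrete, here is a minimal sketch that fills in the description of a float32 `LoDTensor` of shape `[-1, 640, 480]`. It assumes the generated protobuf module `paddle.v2.framework.proto.framework_pb2` (the module imported later in `framework.py`) exposes messages with exactly the field names defined in this design doc; it illustrates the schema rather than a confirmed API.

```python
# Minimal sketch, assuming framework_pb2 matches the messages defined above
# (VarDesc, LoDTensorDesc, TensorDesc, and the top-level DataType enum).
import paddle.v2.framework.proto.framework_pb2 as framework_pb2

var = framework_pb2.VarDesc()
var.name = "image"
var.type = framework_pb2.VarDesc.LOD_TENSOR
# [UNK, 640, 480] is stored as [-1, 640, 480]
var.lod_desc.tensor.data_type = framework_pb2.FP32
var.lod_desc.tensor.dims.extend([-1, 640, 480])
var.lod_desc.lod_level = 0
var.persistable = False
print(var)  # dumps the description in protobuf text format
```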
diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 184ec65d3fa5526b9ec32b376f1a10ca8ca69a6d..14947b6f219ed53312834d2492d879033ddce401 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -19,10 +19,10 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope) proto_library(framework_proto SRCS framework.proto) cc_library(attribute SRCS attribute.cc DEPS framework_proto) -cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim) +cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim op_info) cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute) cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker) -cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc) +cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto) cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope proto_desc) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 3b7cbcd98927be829d185590147adf74cd3d10d1..8fd8c826187e3e3f830fdb318c4edacbad8b7333 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -451,6 +451,7 @@ TEST(Backward, default_attribute) { op->SetInput("X", {"x"}); op->SetInput("Y", {"y"}); op->SetOutput("Out", {"out"}); + op->CheckAttrs(); AppendBackward(program, {}); diff --git a/paddle/framework/block_desc.cc b/paddle/framework/block_desc.cc index 509aa235d3ee226adef15f08f5785866700499f1..b77d5525d4508056c9d6d487e63e500265e1d700 100644 --- a/paddle/framework/block_desc.cc +++ b/paddle/framework/block_desc.cc @@ -91,9 +91,5 @@ BlockDescBind *BlockDescBind::ParentBlock() const { return prog_->Block(static_cast(this->desc_->parent_idx())); } -void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) { - BlockDesc *desc = block.RawPtr(); - this->attrs_[name] = desc; -} } // namespace framework } // namespace paddle diff --git a/paddle/framework/executor_test.cc b/paddle/framework/executor_test.cc index 137e53d849542e48080228e0002931867c4d7fb2..eaa9c9414b631a986af1ec2de0ebad84ed27f983 100644 --- a/paddle/framework/executor_test.cc +++ b/paddle/framework/executor_test.cc @@ -59,6 +59,7 @@ void AddOp(const std::string& type, const VariableNameMap& inputs, op->SetOutput(kv.first, kv.second); } op->SetAttrMap(attrs); + op->CheckAttrs(); } // Tensors in feed value variable will only be in CPUPlace diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc index d3c11ad60a0f9319329a59c16bfc4668cd75b7ae..a5d515bbca729220ca6df5fa07d02f1b3f025109 100644 --- a/paddle/framework/op_desc.cc +++ b/paddle/framework/op_desc.cc @@ -100,6 +100,12 @@ void OpDescBind::SetAttr(const std::string &name, const Attribute &v) { need_update_ = true; } +void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) { + BlockDesc *desc = block.RawPtr(); + this->attrs_[name] = desc; + need_update_ = true; +} + void OpDescBind::SetAttrMap( const std::unordered_map &attr_map) { attrs_ = attr_map; diff --git a/paddle/framework/op_registry.cc b/paddle/framework/op_registry.cc index b118edae17430c8a4dd5c96a2a0c675766e08166..94f75b0f309417e52bb5ad3117797d8c25233257 100644 --- a/paddle/framework/op_registry.cc +++ b/paddle/framework/op_registry.cc @@ -62,11 +62,6 @@ std::unique_ptr OpRegistry::CreateOp(const OpDescBind& 
op_desc) { std::vector> OpRegistry::CreateGradOpDescs( OpDescBind* op_desc) { auto& info = OpInfoMap::Instance().Get(op_desc->Type()); - - if (info.Checker() != nullptr) { - info.Checker()->Check(*op_desc->MutableAttrMap()); - } - return info.grad_op_maker_(*op_desc); } diff --git a/paddle/framework/shape_inference.h b/paddle/framework/shape_inference.h index 64aab16ae54d34fd614add348c7c420b4a8f771d..b93f980cf6d279d18388b9637a2ff45d797ca78e 100644 --- a/paddle/framework/shape_inference.h +++ b/paddle/framework/shape_inference.h @@ -19,9 +19,6 @@ limitations under the License. */ namespace paddle { namespace framework { -// TODO(longfei): Once after both CompileTimeInferShapeContext and -// RuntimeInferShapeContext get merged, we can rename InferShapeContext into -// InferShapeContext so to replace the current InferShapeContext. class InferShapeContext { public: virtual ~InferShapeContext() {} diff --git a/paddle/framework/var_desc.h b/paddle/framework/var_desc.h index 464fece85fe5c674690c2034054e551f14db2138..44368795645664a343e2706fb670f104a42c5c9f 100644 --- a/paddle/framework/var_desc.h +++ b/paddle/framework/var_desc.h @@ -34,6 +34,7 @@ inline std::vector RepeatedToVector( template inline void VectorToRepeated(const std::vector &vec, RepeatedField *repeated_field) { + repeated_field->Clear(); repeated_field->Reserve(vec.size()); for (const auto &elem : vec) { *repeated_field->Add() = elem; @@ -44,6 +45,7 @@ inline void VectorToRepeated(const std::vector &vec, template inline void VectorToRepeated(const std::vector &vec, RepeatedField *repeated_field) { + repeated_field->Clear(); repeated_field->Reserve(vec.size()); for (auto elem : vec) { *repeated_field->Add() = elem; diff --git a/paddle/operators/adam_op.cc b/paddle/operators/adam_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..293b37b7750427cb88efb6dfd5a02dcf7ede24ac --- /dev/null +++ b/paddle/operators/adam_op.cc @@ -0,0 +1,144 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/operators/adam_op.h" + +namespace paddle { +namespace operators { + +class AdamOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Param"), + "Input(Param) of AdamOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Grad"), + "Input(Grad) of AdamOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Moment1"), + "Input(Moment1) of AdamOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Moment2"), + "Input(Moment2) of AdamOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("LearningRate"), + "Input(LearningRate) of AdamOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Beta1Pow"), + "Input(Beta1Pow) of AdamOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Beta2Pow"), + "Input(Beta2Pow) of AdamOp should not be null."); + + PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), + "Output(ParamOut) of AdamOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Moment1Out"), + "Output(Moment1Out) of AdamOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Moment2Out"), + "Output(Moment2Out) of AdamOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Beta1PowOut"), + "Output(Beta1PowOut) of AdamOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Beta2PowOut"), + "Output(Beta2PowOut) of AdamOp should not be null."); + + auto lr_dims = ctx->GetInputDim("LearningRate"); + PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1, + "Learning rate should have 1 dimension"); + auto beta1_pow_dims = ctx->GetInputDim("Beta1Pow"); + PADDLE_ENFORCE_EQ(framework::product(beta1_pow_dims), 1, + "Beta1 power accumulator should have 1 dimension"); + auto beta2_pow_dims = ctx->GetInputDim("Beta2Pow"); + PADDLE_ENFORCE_EQ(framework::product(beta2_pow_dims), 1, + "Beta2 power accumulator should have 1 dimension"); + + auto param_dims = ctx->GetInputDim("Param"); + PADDLE_ENFORCE_EQ( + param_dims, ctx->GetInputDim("Grad"), + "Param and Grad input of AdamOp should have same dimension"); + PADDLE_ENFORCE_EQ( + param_dims, ctx->GetInputDim("Moment1"), + "Param and Moment1 input of AdamOp should have same dimension"); + PADDLE_ENFORCE_EQ( + param_dims, ctx->GetInputDim("Moment2"), + "Param and Moment2 input of AdamOp should have same dimension"); + + ctx->SetOutputDim("ParamOut", param_dims); + ctx->SetOutputDim("Moment1Out", param_dims); + ctx->SetOutputDim("Moment2Out", param_dims); + ctx->SetOutputDim("Beta1PowOut", beta1_pow_dims); + ctx->SetOutputDim("Beta2PowOut", beta2_pow_dims); + } +}; + +class AdamOpMaker : public framework::OpProtoAndCheckerMaker { + public: + AdamOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Param", "(Tensor) Input parameter"); + AddInput("Grad", "(Tensor) Input gradient"); + AddInput("LearningRate", "(Tensor) Learning rate"); + AddInput("Moment1", "(Tensor) Input first moment"); + AddInput("Moment2", "(Tensor) Input second moment"); + AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator"); + AddInput("Beta2Pow", "(Tensor) Input beta2 power accumulator"); + + AddOutput("ParamOut", "(Tensor) Output parameter"); + AddOutput("Moment1Out", "(Tensor) Output first moment"); + AddOutput("Moment2Out", "(Tensor) Output second moment"); + AddOutput("Beta1PowOut", "(Tensor) Output beta1 power accumulator"); + AddOutput("Beta2PowOut", "(Tensor) Output beta2 power 
accumulator"); + + AddAttr("beta1", + "(float, default 0.9) " + "Exponential decay rate for the " + "first moment estimates.") + .SetDefault(0.9f); + AddAttr("beta2", + "(float, default 0.999) " + "exponential decay rate for the " + "second moment estimates.") + .SetDefault(0.999f); + AddAttr("epsilon", + "(float, default 1.0e-8) " + "Constant for numerical stability") + .SetDefault(1.0e-8f); + + AddComment(R"DOC( +Adam Updates Operator. + +This implements the Adam optimizer from Section 2 of the Adam +paper[1]. Adam is a first-order gradient-based optimization +method based on adaptive estimates of lower-order moments. + +Adam updates: + +moment1_out = beta1 * moment1 + (1 − beta1) * grad +moment2_out = beta2 * moment2 + (1 − beta2) * grad * grad +beta1_pow_out = beta1_pow * beta1 +beta2_pow_out = beta2_pow * beta2 +learning_rate_t = learning_rate_t * + sqrt(1 - beta2_pow_out) / (1 - beta1_pow_out) +param_out = param - learning_rate_t * moment1/ (sqrt(moment2) + epsilon) + +References: + [1] Adam: A Method for Stochastic Optimization + (https://arxiv.org/abs/1412.6980) + +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(adam, ops::AdamOp, ops::AdamOpMaker); +REGISTER_OP_CPU_KERNEL(adam, + ops::AdamOpKernel); diff --git a/paddle/operators/adam_op.cu b/paddle/operators/adam_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..a3def912e540454275350209435eb01ae2151331 --- /dev/null +++ b/paddle/operators/adam_op.cu @@ -0,0 +1,20 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/adam_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(adam, + ops::AdamOpKernel); diff --git a/paddle/operators/adam_op.h b/paddle/operators/adam_op.h new file mode 100644 index 0000000000000000000000000000000000000000..789c2f14b32478bf9ddc967fc5725bcf65ed2146 --- /dev/null +++ b/paddle/operators/adam_op.h @@ -0,0 +1,82 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class AdamOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto param_out_tensor = ctx.Output("ParamOut"); + auto moment1_out_tensor = ctx.Output("Moment1Out"); + auto moment2_out_tensor = ctx.Output("Moment2Out"); + auto beta1_pow_out_tensor = ctx.Output("Beta1PowOut"); + auto beta2_pow_out_tensor = ctx.Output("Beta2PowOut"); + + param_out_tensor->mutable_data(ctx.GetPlace()); + moment1_out_tensor->mutable_data(ctx.GetPlace()); + moment2_out_tensor->mutable_data(ctx.GetPlace()); + beta1_pow_out_tensor->mutable_data(ctx.GetPlace()); + beta2_pow_out_tensor->mutable_data(ctx.GetPlace()); + + float beta1 = ctx.Attr("beta1"); + float beta2 = ctx.Attr("beta2"); + float epsilon = ctx.Attr("epsilon"); + + auto param = framework::EigenVector::Flatten( + *ctx.Input("Param")); + auto grad = framework::EigenVector::Flatten( + *ctx.Input("Grad")); + auto moment1 = framework::EigenVector::Flatten( + *ctx.Input("Moment1")); + auto moment2 = framework::EigenVector::Flatten( + *ctx.Input("Moment2")); + auto lr = framework::EigenVector::Flatten( + *ctx.Input("LearningRate")); + auto beta1_pow = framework::EigenVector::Flatten( + *ctx.Input("Beta1Pow")); + auto beta2_pow = framework::EigenVector::Flatten( + *ctx.Input("Beta2Pow")); + auto param_out = framework::EigenVector::Flatten(*param_out_tensor); + auto moment1_out = framework::EigenVector::Flatten(*moment1_out_tensor); + auto moment2_out = framework::EigenVector::Flatten(*moment2_out_tensor); + auto beta1_pow_out = + framework::EigenVector::Flatten(*beta1_pow_out_tensor); + auto beta2_pow_out = + framework::EigenVector::Flatten(*beta2_pow_out_tensor); + auto place = ctx.GetEigenDevice(); + + moment1_out.device(place) = beta1 * moment1 + (1 - beta1) * grad; + moment2_out.device(place) = beta2 * moment2 + (1 - beta2) * grad.square(); + beta1_pow_out.device(place) = beta1_pow * beta1; + beta2_pow_out.device(place) = beta2_pow * beta2; + // All of these are tensors of 1 element + auto lr_t = lr * (1 - beta2_pow_out).sqrt() / (1 - beta1_pow_out); + // Eigen does not support automatic broadcast + // Get dimensions of moment vector to broadcast lr_t + Eigen::DSizes m_dsize(moment1_out_tensor->numel()); + param_out.device(place) = + param - + lr_t.broadcast(m_dsize) * + (moment1_out / (moment2_out.sqrt() + epsilon)); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/decayed_adagrad_op.cc b/paddle/operators/decayed_adagrad_op.cc index ca5141dabc2bb37bf60fa964ede88281bd23550a..7f583f18c8c6ee5025f6525306f9323fb329b030 100644 --- a/paddle/operators/decayed_adagrad_op.cc +++ b/paddle/operators/decayed_adagrad_op.cc @@ -22,7 +22,7 @@ class DecayedAdagradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Param"), "Input(Param) of DecayedAdagradOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Grad"), diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc index e330877fc4283b796dcb5c5d745881884ae491ae..75928f1ec818ab028ea06cfa72273fb99430c3c8 100644 --- a/paddle/operators/uniform_random_op.cc +++ 
b/paddle/operators/uniform_random_op.cc @@ -54,7 +54,7 @@ class UniformRandomOp : public framework::OperatorWithKernel { PADDLE_ENFORCE( ctx->Attrs().Get("min") < ctx->Attrs().Get("max"), "uniform_random's min must less then max"); - auto dims = Attr>("dims"); + auto& dims = ctx->Attrs().Get>("dims"); std::vector temp; temp.reserve(dims.size()); for (auto dim : dims) { diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 0e4bbe8415fd86ab29c6809e7652dc581b4e6004..7ab4e6a451846199d249ee8c6cf24483802a58da 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -204,7 +204,7 @@ void BindOpDesc(py::module &m) { .def("set_attr", &OpDescBind::SetAttr) .def("attr", &OpDescBind::GetAttr) .def("set_block_attr", &OpDescBind::SetBlockAttr) - .def("get_block_attr", &OpDescBind::GetBlockAttr) + .def("block_attr", &OpDescBind::GetBlockAttr) .def("check_attrs", &OpDescBind::CheckAttrs) .def("infer_shape", &OpDescBind::InferShape); } diff --git a/python/paddle/v2/framework/graph.py b/python/paddle/v2/framework/framework.py similarity index 66% rename from python/paddle/v2/framework/graph.py rename to python/paddle/v2/framework/framework.py index 0f0a2847e58a1ca172bf1ba382abb2ebc1ecb8ed..2afbd0c83158d583dd637cceeb35321ddb68f323 100644 --- a/python/paddle/v2/framework/graph.py +++ b/python/paddle/v2/framework/framework.py @@ -1,4 +1,5 @@ import paddle.v2.framework.core as core +import paddle.v2.framework.proto.framework_pb2 as framework_pb2 import collections import numpy as np import copy @@ -106,6 +107,40 @@ class Variable(object): raise ValueError("Not supported numpy dtype " + str(dtype)) +def get_all_op_protos(): + """ + Get all registered op proto from PaddlePaddle C++ end. + :return: A list of registered OpProto. + """ + protostrs = core.get_all_op_protos() + ret_values = [] + for pbstr in protostrs: + op_proto = framework_pb2.OpProto.FromString(str(pbstr)) + ret_values.append(op_proto) + return ret_values + + +class OpProtoHolder(object): + @classmethod + def instance(cls): + if not hasattr(cls, '_instance'): + cls._instance = cls() + return cls._instance + + def __init__(self): + assert not hasattr( + self.__class__, + '_instance'), 'Please use `instance()` to get OpProtoHolder opject!' + op_protos = get_all_op_protos() + self.op_proto_map = {} + for proto in op_protos: + self.op_proto_map[proto.type] = proto + + def get_op_proto(self, type): + assert type in self.op_proto_map, "Operator \"%s\" has not been registered." % type + return self.op_proto_map[type] + + class Operator(object): def __init__(self, block, @@ -116,20 +151,89 @@ class Operator(object): attrs=None): self.block = block self.desc = desc - if type is not None: - # TODO. - pass + if len(self.desc.type()) != 0: + return + if type is None: + raise ValueError( + "`type` to initilized an Operator can not be None.") + self.desc.set_type(type) + proto = OpProtoHolder.instance().get_op_proto(type) + if inputs is not None: - # TODO - pass + for in_proto in proto.inputs: + in_argus = inputs[in_proto.name] + if not isinstance(in_argus, list): + in_argus = [in_argus] + if not in_proto.duplicable and len(in_argus) > 1: + raise ValueError( + "Input %s expects only one input, but %d are given." 
% + (in_proto.name, len(in_argus))) + in_argu_names = [] + for argu in in_argus: + in_argu_names.append(argu.name) + self.desc.set_input(in_proto.name, in_argu_names) + if outputs is not None: - # TODO - pass + for out_proto in proto.outputs: + out_argus = outputs[out_proto.name] + if not isinstance(out_argus, list): + out_argus = [out_argus] + if not out_proto.duplicable and len(out_argus) > 1: + raise ValueError( + "Output %s expects only one output, but %d are given." % + (out_proto.name, len(out_argus))) + out_argu_names = [] + for argu in out_argus: + out_argu_names.append(argu.name) + argu.op = self + self.desc.set_output(out_proto.name, out_argu_names) + if attrs is not None: - # TODO - pass + for attr in proto.attrs: + attr_name = attr.name + if not attr_name in attrs: + continue + if not isinstance(attrs[attr_name], Block): + self.desc.set_attr(attr_name, attrs[attr_name]) + else: + self.desc.set_block_attr(attr_name, attrs[attr_name].desc) + + self.desc.check_attrs() + self.desc.infer_shape(self.block.desc) + + @property + def type(self): + return self.desc.type() + + def input(self, name): + return self.desc.input(name) + + @property + def input_names(self): + return self.desc.input_names() + + def output(self, name): + return self.desc.output(name) + + @property + def output_names(self): + return self.desc.output_names() + + def has_attr(self, name): + return self.desc.has_attr(name) + + def attr_type(self, name): + return self.desc.attr_type(name) + + @property + def attr_names(self): + return self.desc.attr_names() + + def attr(self, name): + return self.desc.attr(name) - # TODO: Getters + def block_attr(self, name): + return self.desc.block_attr(name) class Block(object): diff --git a/python/paddle/v2/framework/tests/test_adam_op.py b/python/paddle/v2/framework/tests/test_adam_op.py new file mode 100644 index 0000000000000000000000000000000000000000..ff6faafa6e2119fde11b9eb6cd2a65a75334ebe6 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_adam_op.py @@ -0,0 +1,186 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestAdamOp1(OpTest): + def setUp(self): + '''Test Adam Op with supplied attributes + ''' + self.op_type = "adam" + param = np.random.uniform(-1, 1, (102, 105)).astype("float32") + grad = np.random.uniform(-1, 1, (102, 105)).astype("float32") + moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32") + # The second moment is positive + moment2 = np.random.random((102, 105)).astype("float32") + + learning_rate = 0.004 + beta1 = 0.78 + beta2 = 0.836 + epsilon = 1e-4 + beta1_pow = beta1**10 + beta2_pow = beta2**10 + + self.inputs = { + 'Param': param, + 'Grad': grad, + 'Moment1': moment1, + 'Moment2': moment2, + 'LearningRate': np.array([learning_rate]).astype("float32"), + 'Beta1Pow': np.array([beta1_pow]).astype("float32"), + 'Beta2Pow': np.array([beta2_pow]).astype("float32") + } + + self.attrs = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2} + + param_out, moment1_out, moment2_out, beta1_pow_out, \ + beta2_pow_out = adam_step(self.inputs, self.attrs) + + self.outputs = { + 'Moment1Out': moment1_out, + 'Moment2Out': moment2_out, + 'Beta1PowOut': beta1_pow_out, + 'Beta2PowOut': beta2_pow_out, + 'ParamOut': param_out + } + + def test_check_output(self): + self.check_output() + + +class TestAdamOp2(OpTest): + def setUp(self): + '''Test Adam Op with supplied attributes + ''' + self.op_type = "adam" + param = np.random.uniform(-1, 1, (102, 105)).astype("float32") + grad = np.random.uniform(-1, 1, (102, 
105)).astype("float32") + moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32") + # The second moment is positive + moment2 = np.random.random((102, 105)).astype("float32") + + learning_rate = 0.001 + beta1 = 0.9 + beta2 = 0.999 + epsilon = 1e-8 + beta1_pow = beta1**10 + beta2_pow = beta2**10 + + self.inputs = { + 'Param': param, + 'Grad': grad, + 'Moment1': moment1, + 'Moment2': moment2, + 'LearningRate': np.array([learning_rate]).astype("float32"), + 'Beta1Pow': np.array([beta1_pow]).astype("float32"), + 'Beta2Pow': np.array([beta2_pow]).astype("float32") + } + + attributes = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2} + + param_out, moment1_out, moment2_out, beta1_pow_out, \ + beta2_pow_out = adam_step(self.inputs, attributes) + + self.outputs = { + 'Moment1Out': moment1_out, + 'Moment2Out': moment2_out, + 'Beta1PowOut': beta1_pow_out, + 'Beta2PowOut': beta2_pow_out, + 'ParamOut': param_out + } + + def test_check_output(self): + self.check_output() + + +class TestAdamOpMultipleSteps(OpTest): + def setUp(self): + '''Test Adam Operator with supplied attributes + ''' + self.op_type = "adam" + self.num_steps = 10 + + param = np.random.uniform(-1, 1, (102, 105)).astype("float32") + grad = np.random.uniform(-1, 1, (102, 105)).astype("float32") + moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32") + # The second moment is positive + moment2 = np.random.random((102, 105)).astype("float32") + + learning_rate = 0.001 + beta1 = 0.9 + beta2 = 0.999 + epsilon = 1e-8 + beta1_pow = beta1**10 + beta2_pow = beta2**10 + + self.inputs = { + 'Param': param, + 'Grad': grad, + 'Moment1': moment1, + 'Moment2': moment2, + 'LearningRate': np.array([learning_rate]).astype("float32"), + 'Beta1Pow': np.array([beta1_pow]).astype("float32"), + 'Beta2Pow': np.array([beta2_pow]).astype("float32") + } + + self.attrs = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2} + + def test_check_output(self): + for _ in range(self.num_steps): + param_out, moment1_out, moment2_out, beta1_pow_out, \ + beta2_pow_out = adam_step(self.inputs, self.attrs) + + self.outputs = { + 'Moment1Out': moment1_out, + 'Moment2Out': moment2_out, + 'Beta1PowOut': beta1_pow_out, + 'Beta2PowOut': beta2_pow_out, + 'ParamOut': param_out + } + + # Verify output for this step + self.check_output() + + # Output of this step becomes input for next step + self.inputs['Param'] = param_out + self.inputs['Moment1'] = moment1_out + self.inputs['Moment2'] = moment2_out + self.inputs['Beta1Pow'] = beta1_pow_out + self.inputs['Beta2Pow'] = beta2_pow_out + + # Randomize gradient for next step + self.inputs['Grad'] = np.random.uniform( + -1, 1, (102, 105)).astype("float32") + + +def adam_step(inputs, attributes): + ''' + Simulate one step of the adam optimizer + :param inputs: dict of inputs + :param attributes: dict of attributes + :return tuple: tuple of output param, moment1, moment2, + beta1 power accumulator and beta2 power accumulator + ''' + param = inputs['Param'] + grad = inputs['Grad'] + moment1 = inputs['Moment1'] + moment2 = inputs['Moment2'] + lr = inputs['LearningRate'] + beta1_pow = inputs['Beta1Pow'] + beta2_pow = inputs['Beta2Pow'] + + beta1 = attributes['beta1'] + beta2 = attributes['beta2'] + epsilon = attributes['epsilon'] + + moment1_out = beta1 * moment1 + (1 - beta1) * grad + moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad) + beta1_pow_out = beta1_pow * beta1 + beta2_pow_out = beta2_pow * beta2 + lr_t = lr * np.sqrt(1 - beta2_pow_out) / (1 - beta1_pow_out) + param_out = param - lr_t * 
(moment1_out / (np.sqrt(moment2_out) + epsilon)) + return param_out, moment1_out, moment2_out, beta1_pow_out, beta2_pow_out + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_conv2d_op.py b/python/paddle/v2/framework/tests/test_conv2d_op.py index bfbb213d75df73250182b38407aa1dd17297d450..2fb808944ac97f2bdcb05336a2205346ded65a4d 100644 --- a/python/paddle/v2/framework/tests/test_conv2d_op.py +++ b/python/paddle/v2/framework/tests/test_conv2d_op.py @@ -3,71 +3,56 @@ import numpy as np from op_test import OpTest +def conv2d_forward_naive(input, filter, group, conv_param): + in_n, in_c, in_h, in_w = input.shape + out_c, f_c, f_h, f_w = filter.shape + assert f_c * group == in_c + assert np.mod(out_c, group) == 0 + sub_out_c = out_c / group + + stride, pad = conv_param['stride'], conv_param['pad'] + out_h = 1 + (in_h + 2 * pad[0] - f_h) / stride[0] + out_w = 1 + (in_w + 2 * pad[1] - f_w) / stride[1] + out = np.zeros((in_n, out_c, out_h, out_w)) + + input_pad = np.pad(input, ((0, ), (0, ), (pad[0], ), (pad[1], )), + mode='constant', + constant_values=0) + for i in range(out_h): + for j in range(out_w): + for g in range(group): + input_pad_masked = \ + input_pad[:, g * f_c:(g + 1) * f_c, + i * stride[0]:i * stride[0] + f_h, + j * stride[1]:j * stride[1] + f_w] + + f_sub = filter[g * sub_out_c:(g + 1) * sub_out_c, :, :, :] + for k in range(sub_out_c): + out[:, g * sub_out_c + k, i, j] = \ + np.sum(input_pad_masked * f_sub[k, :, :, :], + axis=(1, 2, 3)) + + return out + + class TestConv2dOp(OpTest): def setUp(self): - self.init_groups() - self.init_optype() - batch_size = 2 - input_channels = 3 - input_height = 5 - input_width = 5 - output_channels = 6 - filter_height = 3 - filter_width = 3 - stride = 1 - padding = 0 - output_height = (input_height - filter_height + 2 * padding - ) / stride + 1 - output_width = (input_width - filter_width + 2 * padding) / stride + 1 - input = np.random.random((batch_size, input_channels, input_height, - input_width)).astype("float32") - - filter = np.random.random( - (output_channels, input_channels / self.groups, filter_height, - filter_width)).astype("float32") - output = np.ndarray( - (batch_size, output_channels, output_height, output_width)) + self.init_op_type() + self.init_group() + self.init_test_case() + + conv2d_param = {'stride': self.stride, 'pad': self.pad} + input = np.random.random(self.input_size).astype("float32") + filter = np.random.random(self.filter_size).astype("float32") + output = conv2d_forward_naive(input, filter, self.groups, conv2d_param) self.inputs = {'Input': input, 'Filter': filter} self.attrs = { - 'strides': [1, 1], - 'paddings': [0, 0], - 'dilations': [1, 1], - 'groups': self.groups + 'strides': self.stride, + 'paddings': self.pad, + 'groups': self.groups, + 'dilations': self.dilations } - - output_group_channels = output_channels / self.groups - input_group_channels = input_channels / self.groups - for batchid in xrange(batch_size): - for group in xrange(self.groups): - for outchannelid in range(group * output_group_channels, - (group + 1) * output_group_channels): - for rowid in xrange(output_height): - for colid in xrange(output_width): - start_h = (rowid * stride) - padding - start_w = (colid * stride) - padding - output_value = 0.0 - for inchannelid in range( - group * input_group_channels, - (group + 1) * input_group_channels): - for frowid in xrange(filter_height): - for fcolid in xrange(filter_width): - input_value = 0.0 - inrowid = start_h + frowid - incolid = start_w + 
fcolid - if ((inrowid >= 0 and - inrowid < input_height) and - (incolid >= 0 and - incolid < input_width)): - input_value = input[batchid][ - inchannelid][inrowid][incolid] - filter_value = filter[outchannelid][ - inchannelid % input_group_channels][ - frowid][fcolid] - output_value += input_value * filter_value - output[batchid][outchannelid][rowid][ - colid] = output_value - self.outputs = {'Output': output} def test_check_output(self): @@ -91,30 +76,47 @@ class TestConv2dOp(OpTest): max_relative_error=0.05, no_grad_set=set(['Input'])) - def init_groups(self): + def init_test_case(self): + # self.groups = 1 + # self.op_type = "conv2d" + self.pad = [0, 0] + self.stride = [1, 1] + self.dilations = [1, 1] + self.input_size = [2, 3, 5, 5] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] / self.groups + self.filter_size = [6, f_c, 3, 3] + + def init_group(self): self.groups = 1 - def init_optype(self): + def init_op_type(self): self.op_type = "conv2d" class TestWithGroup(TestConv2dOp): - def init_groups(self): + def init_group(self): self.groups = 3 + def init_op_type(self): + self.op_type = "conv2d" -class TestCudnn2d(TestConv2dOp): - def init_optype(self): - self.op_type = "conv_cudnn" +class TestCudnn(TestConv2dOp): + def init_group(self): + self.groups = 1 -class TestCudnn2dWithGroup(TestConv2dOp): - def init_optype(self): + def init_op_type(self): self.op_type = "conv_cudnn" - def init_groups(self): + +class TestCudnnWithGroup(TestConv2dOp): + def init_group(self): self.groups = 3 + def init_op_type(self): + self.op_type = "conv_cudnn" + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/framework/tests/test_operator_desc.py b/python/paddle/v2/framework/tests/test_operator_desc.py new file mode 100644 index 0000000000000000000000000000000000000000..d7a85d8e4e883efd268c53a0e4977533040a0a14 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_operator_desc.py @@ -0,0 +1,76 @@ +import unittest +from paddle.v2.framework.framework import Variable, g_program +import paddle.v2.framework.core as core + + +class TestOperator(unittest.TestCase): + def test_error_type(self): + block = g_program.create_block() + try: + block.append_op() + self.assertFail() + except ValueError as v_err: + self.assertEqual( + v_err.message, + "`type` to initilized an Operator can not be None.") + try: + block.append_op(type="no_such_op") + self.assertFail() + except AssertionError as a_err: + self.assertEqual(a_err.message, + "Operator \"no_such_op\" has not been registered.") + + def test_op_desc_creation(self): + block = g_program.current_block() + mul_x = block.create_var( + dtype="float32", shape=[5, 10], lod_level=0, name="mul.x") + mul_y = block.create_var( + dtype="float32", shape=[10, 8], lod_level=0, name="mul.y") + mul_out = block.create_var( + dtype="float32", shape=[5, 8], lod_level=0, name="mul.out") + mul_op = block.append_op( + type="mul", + inputs={"X": [mul_x], + "Y": mul_y}, + outputs={"Out": [mul_out]}, + attrs={"x_num_col_dims": 1}) + self.assertEqual(mul_op.type, "mul") + self.assertEqual(mul_op.input_names, ["X", "Y"]) + self.assertEqual(mul_op.input("X"), ["mul.x"]) + self.assertEqual(mul_op.input("Y"), ["mul.y"]) + self.assertEqual(mul_op.output_names, ["Out"]) + self.assertEqual(mul_op.output("Out"), ["mul.out"]) + self.assertEqual( + set(mul_op.attr_names), set(["x_num_col_dims", "y_num_col_dims"])) + self.assertEqual(mul_op.has_attr("x_num_col_dims"), True) + self.assertEqual(mul_op.attr_type("x_num_col_dims"), 
core.AttrType.INT) + self.assertEqual(mul_op.attr("x_num_col_dims"), 1) + self.assertEqual(mul_op.has_attr("y_num_col_dims"), True) + self.assertEqual(mul_op.attr_type("y_num_col_dims"), core.AttrType.INT) + self.assertEqual(mul_op.attr("y_num_col_dims"), 1) + self.assertEqual(mul_out.op, mul_op) + + def test_mult_input(self): + block = g_program.current_block() + sum_x1 = block.create_var( + dtype="int", shape=[3, 4], lod_level=0, name="sum.x1") + sum_x2 = block.create_var( + dtype="int", shape=[3, 4], lod_level=0, name="sum.x2") + sum_x3 = block.create_var( + dtype="int", shape=[3, 4], lod_level=0, name="sum.x3") + sum_out = block.create_var( + dtype="int", shape=[3, 4], lod_level=0, name="sum.out") + sum_op = block.append_op( + type="sum", + inputs={"X": [sum_x1, sum_x2, sum_x3]}, + outputs={"Out": sum_out}) + self.assertEqual(sum_op.type, "sum") + self.assertEqual(sum_op.input_names, ["X"]) + self.assertEqual(sum_op.input("X"), ["sum.x1", "sum.x2", "sum.x3"]) + self.assertEqual(sum_op.output_names, ["Out"]) + self.assertEqual(sum_op.output("Out"), ["sum.out"]) + self.assertEqual(sum_out.op, sum_op) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_parameter.py b/python/paddle/v2/framework/tests/test_parameter.py index 3b5d38f257e6f51be30d9f1fa42285461b2a0eb7..1ac0cdd99f1b7c15d64ae9d2c465d5a9d563bd80 100644 --- a/python/paddle/v2/framework/tests/test_parameter.py +++ b/python/paddle/v2/framework/tests/test_parameter.py @@ -1,5 +1,5 @@ import unittest -from paddle.v2.framework.graph import g_program +from paddle.v2.framework.framework import g_program import paddle.v2.framework.core as core diff --git a/python/paddle/v2/framework/tests/test_program.py b/python/paddle/v2/framework/tests/test_program.py index 83e184494ad235f6493a7ea8e25886b1e35004ee..64b781e6ea21bff90646d312a157d60852f276df 100644 --- a/python/paddle/v2/framework/tests/test_program.py +++ b/python/paddle/v2/framework/tests/test_program.py @@ -1,7 +1,7 @@ import unittest import paddle.v2.framework.core as core -from paddle.v2.framework.graph import g_program +from paddle.v2.framework.framework import g_program class TestProgram(unittest.TestCase): diff --git a/python/paddle/v2/framework/tests/test_protobuf_descs.py b/python/paddle/v2/framework/tests/test_protobuf_descs.py index 3db1e79ce43b7f559c7caab8397817b76d56161e..af5ed6801fa7b87e9193df78c7d28cf637eafa42 100644 --- a/python/paddle/v2/framework/tests/test_protobuf_descs.py +++ b/python/paddle/v2/framework/tests/test_protobuf_descs.py @@ -53,7 +53,7 @@ class TestOpDesc(unittest.TestCase): self.assertEqual(8, len(op.attr_names())) op.set_block_attr("block_attr", prog.block(0)) - self.assertEqual(0, op.get_block_attr("block_attr")) + self.assertEqual(0, op.block_attr("block_attr")) mul_op = block.append_op() mul_op.set_type("mul") diff --git a/python/paddle/v2/framework/tests/test_variable.py b/python/paddle/v2/framework/tests/test_variable.py index 8ea1083ff6535d2d517f2ac587a956bfed906f03..695aaaee6c0c1d035349b1d1716c24bab81e607b 100644 --- a/python/paddle/v2/framework/tests/test_variable.py +++ b/python/paddle/v2/framework/tests/test_variable.py @@ -1,5 +1,5 @@ import unittest -from paddle.v2.framework.graph import Variable, g_program +from paddle.v2.framework.framework import Variable, g_program import paddle.v2.framework.core as core import numpy as np
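One design point worth illustrating from the new adam operator: the Eigen kernel in `adam_op.h` broadcasts the one-element `lr_t` tensor by hand (via `broadcast(m_dsize)`) because Eigen does not broadcast automatically, while the numpy reference `adam_step` in `test_adam_op.py` relies on numpy's implicit broadcasting. The following is a minimal, self-contained sketch of a single update step with made-up values; the names and numbers are illustrative only and are not part of the patch.

```python
import numpy as np

# Illustrative walk-through of the Adam update rules quoted in adam_op.cc's DOC
# comment. All shapes and values below are made up for demonstration.
param = np.array([0.1, -0.2, 0.3], dtype=np.float32)
grad = np.array([0.01, 0.02, -0.03], dtype=np.float32)
moment1 = np.zeros_like(param)
moment2 = np.zeros_like(param)
lr = np.array([0.001], dtype=np.float32)         # 1-element tensor, as the op expects
beta1_pow = np.array([0.9], dtype=np.float32)    # beta1 ** step
beta2_pow = np.array([0.999], dtype=np.float32)  # beta2 ** step
beta1, beta2, epsilon = 0.9, 0.999, 1e-8

moment1_out = beta1 * moment1 + (1 - beta1) * grad
moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad)
beta1_pow_out = beta1_pow * beta1
beta2_pow_out = beta2_pow * beta2
# lr_t keeps shape (1,); numpy broadcasts it over the parameter, which is what
# the Eigen kernel does explicitly with lr_t.broadcast(m_dsize).
lr_t = lr * np.sqrt(1 - beta2_pow_out) / (1 - beta1_pow_out)
param_out = param - lr_t * (moment1_out / (np.sqrt(moment2_out) + epsilon))
print(param_out)
```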