diff --git a/doc/design/images/graph_construction_example.dot b/doc/design/images/graph_construction_example.dot index 8d1b673abf6b78c851676fa379dc850c4818f0e5..e115f9844bae6ad24f638c8ed4749cea8aff06a9 100644 --- a/doc/design/images/graph_construction_example.dot +++ b/doc/design/images/graph_construction_example.dot @@ -33,7 +33,6 @@ digraph ImageClassificationGraph { cost -> MSE_Grad [color=red]; d_cost -> MSE_Grad [color=red]; - x -> MSE_Grad [color=red]; l -> MSE_Grad [color=red]; y -> MSE_Grad -> d_y [color=red]; diff --git a/doc/design/images/graph_construction_example_all.png b/doc/design/images/graph_construction_example_all.png index 181187503472d15779b87284105841168b3945c4..261611a5721f9aa97874f7e6d897fe48cf667db2 100644 Binary files a/doc/design/images/graph_construction_example_all.png and b/doc/design/images/graph_construction_example_all.png differ diff --git a/doc/design/images/graph_construction_example_forward_backward.png b/doc/design/images/graph_construction_example_forward_backward.png index 3049a9315fd616464dec54e33064cb75598ca536..4c69687f4a6a181138f3df72ce5e8aa48487b5be 100644 Binary files a/doc/design/images/graph_construction_example_forward_backward.png and b/doc/design/images/graph_construction_example_forward_backward.png differ diff --git a/doc/design/images/graph_construction_example_forward_only.png b/doc/design/images/graph_construction_example_forward_only.png index 25d19088cbf0b5f68cf734f2ff21eba8af4a2860..e668c16e0cac73acb4e5dc2b1827557ae77126b4 100644 Binary files a/doc/design/images/graph_construction_example_forward_only.png and b/doc/design/images/graph_construction_example_forward_only.png differ diff --git a/paddle/api/CMakeLists.txt b/paddle/api/CMakeLists.txt index d7b3d2bdec1687425df804c0d56d568241f9e8b0..d6b8464100d4497876aa3f6f7cbc666aafae4bfc 100644 --- a/paddle/api/CMakeLists.txt +++ b/paddle/api/CMakeLists.txt @@ -26,7 +26,7 @@ FILE(GLOB PY_PADDLE_PYTHON_FILES ${PADDLE_SOURCE_DIR}/paddle/py_paddle/*.py) SET_SOURCE_FILES_PROPERTIES(Paddle.i PROPERTIES CPLUSPLUS ON) SET(CMAKE_SWIG_OUTDIR ${CMAKE_CURRENT_BINARY_DIR}) -SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-parentheses-equality -Wno-missing-field-initializers -Wno-self-assign") +SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-parentheses-equality -Wno-missing-field-initializers -Wno-self-assign -ftls-model=global-dynamic") SET(SWIG_MODULE_swig_paddle_EXTRA_DEPS paddle_parameter diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 6b34c3bbcfbdb0c36381df7de4dd227e317829e5..184ec65d3fa5526b9ec32b376f1a10ca8ca69a6d 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -42,12 +42,14 @@ add_custom_command(TARGET framework_py_proto POST_BUILD cc_library(backward SRCS backward.cc DEPS net_op) cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context) -cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward ${GLOB_OP_LIB}) -#if(WITH_GPU) -# nv_test(executor_test SRCS executor_test.cc DEPS executor) -#else() -# cc_test(executor_test SRCS executor_test.cc DEPS executor) -#endif() +cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward) +set(EXECUTOR_TEST_OP elementwise_add_op gaussian_random_op feed_op fetch_op + mul_op sum_op squared_l2_distance_op fill_constant_op sgd_op) +if(WITH_GPU) + nv_test(executor_test SRCS executor_test.cc DEPS executor ${EXECUTOR_TEST_OP}) +else() + cc_test(executor_test SRCS executor_test.cc DEPS 
executor ${EXECUTOR_TEST_OP}) +endif() cc_library(tensor_array SRCS tensor_array.cc DEPS lod_tensor) cc_test(tensor_array_test SRCS tensor_array_test.cc DEPS tensor_array place) diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 3b7cbcd98927be829d185590147adf74cd3d10d1..8fd8c826187e3e3f830fdb318c4edacbad8b7333 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -451,6 +451,7 @@ TEST(Backward, default_attribute) { op->SetInput("X", {"x"}); op->SetInput("Y", {"y"}); op->SetOutput("Out", {"out"}); + op->CheckAttrs(); AppendBackward(program, {}); diff --git a/paddle/framework/block_desc.cc b/paddle/framework/block_desc.cc index 509aa235d3ee226adef15f08f5785866700499f1..b77d5525d4508056c9d6d487e63e500265e1d700 100644 --- a/paddle/framework/block_desc.cc +++ b/paddle/framework/block_desc.cc @@ -91,9 +91,5 @@ BlockDescBind *BlockDescBind::ParentBlock() const { return prog_->Block(static_cast(this->desc_->parent_idx())); } -void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) { - BlockDesc *desc = block.RawPtr(); - this->attrs_[name] = desc; -} } // namespace framework } // namespace paddle diff --git a/paddle/framework/executor_test.cc b/paddle/framework/executor_test.cc index 7f6d8fe6a4aec9fdc39b4ffc0837a03e355ec937..eaa9c9414b631a986af1ec2de0ebad84ed27f983 100644 --- a/paddle/framework/executor_test.cc +++ b/paddle/framework/executor_test.cc @@ -25,6 +25,16 @@ limitations under the License. */ #include "paddle/framework/op_registry.h" #include "paddle/framework/operator.h" +USE_OP(elementwise_add); +USE_OP(gaussian_random); +USE_OP(feed); +USE_OP(fetch); +USE_OP(mul); +USE_OP(sum); +USE_OP(squared_l2_distance); +USE_OP(fill_constant); +USE_OP(sgd); + using namespace paddle::platform; using namespace paddle::framework; @@ -49,6 +59,7 @@ void AddOp(const std::string& type, const VariableNameMap& inputs, op->SetOutput(kv.first, kv.second); } op->SetAttrMap(attrs); + op->CheckAttrs(); } // Tensors in feed value variable will only be in CPUPlace diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc index d3c11ad60a0f9319329a59c16bfc4668cd75b7ae..a5d515bbca729220ca6df5fa07d02f1b3f025109 100644 --- a/paddle/framework/op_desc.cc +++ b/paddle/framework/op_desc.cc @@ -100,6 +100,12 @@ void OpDescBind::SetAttr(const std::string &name, const Attribute &v) { need_update_ = true; } +void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) { + BlockDesc *desc = block.RawPtr(); + this->attrs_[name] = desc; + need_update_ = true; +} + void OpDescBind::SetAttrMap( const std::unordered_map &attr_map) { attrs_ = attr_map; diff --git a/paddle/framework/op_registry.cc b/paddle/framework/op_registry.cc index b118edae17430c8a4dd5c96a2a0c675766e08166..94f75b0f309417e52bb5ad3117797d8c25233257 100644 --- a/paddle/framework/op_registry.cc +++ b/paddle/framework/op_registry.cc @@ -62,11 +62,6 @@ std::unique_ptr OpRegistry::CreateOp(const OpDescBind& op_desc) { std::vector> OpRegistry::CreateGradOpDescs( OpDescBind* op_desc) { auto& info = OpInfoMap::Instance().Get(op_desc->Type()); - - if (info.Checker() != nullptr) { - info.Checker()->Check(*op_desc->MutableAttrMap()); - } - return info.grad_op_maker_(*op_desc); } diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 15f80b57206c90f689acfdcac60a0d9011025fc0..97a142d5f1661704fede858b28ff0d5487c66fab 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -289,6 +289,15 @@ class ExecutionContext 
{ return device_context_; } +#ifdef PADDLE_WITH_CUDA + const platform::CUDADeviceContext& cuda_device_context() const { + PADDLE_ENFORCE(platform::is_gpu_place(device_context_.GetPlace())); + auto cuda_ctx = + reinterpret_cast(&device_context_); + return *cuda_ctx; + } +#endif + private: const OperatorBase& op_; const Scope& scope_; diff --git a/paddle/framework/shape_inference.h b/paddle/framework/shape_inference.h index 64aab16ae54d34fd614add348c7c420b4a8f771d..b93f980cf6d279d18388b9637a2ff45d797ca78e 100644 --- a/paddle/framework/shape_inference.h +++ b/paddle/framework/shape_inference.h @@ -19,9 +19,6 @@ limitations under the License. */ namespace paddle { namespace framework { -// TODO(longfei): Once after both CompileTimeInferShapeContext and -// RuntimeInferShapeContext get merged, we can rename InferShapeContext into -// InferShapeContext so to replace the current InferShapeContext. class InferShapeContext { public: virtual ~InferShapeContext() {} diff --git a/paddle/framework/var_desc.h b/paddle/framework/var_desc.h index 464fece85fe5c674690c2034054e551f14db2138..44368795645664a343e2706fb670f104a42c5c9f 100644 --- a/paddle/framework/var_desc.h +++ b/paddle/framework/var_desc.h @@ -34,6 +34,7 @@ inline std::vector RepeatedToVector( template inline void VectorToRepeated(const std::vector &vec, RepeatedField *repeated_field) { + repeated_field->Clear(); repeated_field->Reserve(vec.size()); for (const auto &elem : vec) { *repeated_field->Add() = elem; @@ -44,6 +45,7 @@ inline void VectorToRepeated(const std::vector &vec, template inline void VectorToRepeated(const std::vector &vec, RepeatedField *repeated_field) { + repeated_field->Clear(); repeated_field->Reserve(vec.size()); for (auto elem : vec) { *repeated_field->Add() = elem; diff --git a/paddle/operators/adam_op.cc b/paddle/operators/adam_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..293b37b7750427cb88efb6dfd5a02dcf7ede24ac --- /dev/null +++ b/paddle/operators/adam_op.cc @@ -0,0 +1,144 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/operators/adam_op.h" + +namespace paddle { +namespace operators { + +class AdamOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Param"), + "Input(Param) of AdamOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Grad"), + "Input(Grad) of AdamOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Moment1"), + "Input(Moment1) of AdamOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Moment2"), + "Input(Moment2) of AdamOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("LearningRate"), + "Input(LearningRate) of AdamOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Beta1Pow"), + "Input(Beta1Pow) of AdamOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Beta2Pow"), + "Input(Beta2Pow) of AdamOp should not be null."); + + PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), + "Output(ParamOut) of AdamOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Moment1Out"), + "Output(Moment1Out) of AdamOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Moment2Out"), + "Output(Moment2Out) of AdamOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Beta1PowOut"), + "Output(Beta1PowOut) of AdamOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Beta2PowOut"), + "Output(Beta2PowOut) of AdamOp should not be null."); + + auto lr_dims = ctx->GetInputDim("LearningRate"); + PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1, + "Learning rate should have 1 dimension"); + auto beta1_pow_dims = ctx->GetInputDim("Beta1Pow"); + PADDLE_ENFORCE_EQ(framework::product(beta1_pow_dims), 1, + "Beta1 power accumulator should have 1 dimension"); + auto beta2_pow_dims = ctx->GetInputDim("Beta2Pow"); + PADDLE_ENFORCE_EQ(framework::product(beta1_pow_dims), 1, + "Beta1 power accumulator should have 1 dimension"); + + auto param_dims = ctx->GetInputDim("Param"); + PADDLE_ENFORCE_EQ( + param_dims, ctx->GetInputDim("Grad"), + "Param and Grad input of AdamOp should have same dimension"); + PADDLE_ENFORCE_EQ( + param_dims, ctx->GetInputDim("Moment1"), + "Param and Moment input of AdamOp should have same dimension"); + PADDLE_ENFORCE_EQ( + param_dims, ctx->GetInputDim("Moment2"), + "Param and InfNorm input of AdamOp should have same dimension"); + + ctx->SetOutputDim("ParamOut", param_dims); + ctx->SetOutputDim("Moment1Out", param_dims); + ctx->SetOutputDim("Moment2Out", param_dims); + ctx->SetOutputDim("Beta1PowOut", beta1_pow_dims); + ctx->SetOutputDim("Beta2PowOut", beta2_pow_dims); + } +}; + +class AdamOpMaker : public framework::OpProtoAndCheckerMaker { + public: + AdamOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Param", "(Tensor) Input parameter"); + AddInput("Grad", "(Tensor) Input gradient"); + AddInput("LearningRate", "(Tensor) Learning rate"); + AddInput("Moment1", "(Tensor) Input first moment"); + AddInput("Moment2", "(Tensor) Input second moment"); + AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator"); + AddInput("Beta2Pow", "(Tensor) Input beta2 power accumulator"); + + AddOutput("ParamOut", "(Tensor) Output parameter"); + AddOutput("Moment1Out", "(Tensor) Output first moment"); + AddOutput("Moment2Out", "(Tensor) Output second moment"); + AddOutput("Beta1PowOut", "(Tensor) Output beta1 power accumulator"); + AddOutput("Beta2PowOut", "(Tensor) Output beta2 power 
accumulator"); + + AddAttr("beta1", + "(float, default 0.9) " + "Exponential decay rate for the " + "first moment estimates.") + .SetDefault(0.9f); + AddAttr("beta2", + "(float, default 0.999) " + "exponential decay rate for the " + "second moment estimates.") + .SetDefault(0.999f); + AddAttr("epsilon", + "(float, default 1.0e-8) " + "Constant for numerical stability") + .SetDefault(1.0e-8f); + + AddComment(R"DOC( +Adam Updates Operator. + +This implements the Adam optimizer from Section 2 of the Adam +paper[1]. Adam is a first-order gradient-based optimization +method based on adaptive estimates of lower-order moments. + +Adam updates: + +moment1_out = beta1 * moment1 + (1 − beta1) * grad +moment2_out = beta2 * moment2 + (1 − beta2) * grad * grad +beta1_pow_out = beta1_pow * beta1 +beta2_pow_out = beta2_pow * beta2 +learning_rate_t = learning_rate_t * + sqrt(1 - beta2_pow_out) / (1 - beta1_pow_out) +param_out = param - learning_rate_t * moment1/ (sqrt(moment2) + epsilon) + +References: + [1] Adam: A Method for Stochastic Optimization + (https://arxiv.org/abs/1412.6980) + +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(adam, ops::AdamOp, ops::AdamOpMaker); +REGISTER_OP_CPU_KERNEL(adam, + ops::AdamOpKernel); diff --git a/paddle/operators/adam_op.cu b/paddle/operators/adam_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..a3def912e540454275350209435eb01ae2151331 --- /dev/null +++ b/paddle/operators/adam_op.cu @@ -0,0 +1,20 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/adam_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(adam, + ops::AdamOpKernel); diff --git a/paddle/operators/adam_op.h b/paddle/operators/adam_op.h new file mode 100644 index 0000000000000000000000000000000000000000..789c2f14b32478bf9ddc967fc5725bcf65ed2146 --- /dev/null +++ b/paddle/operators/adam_op.h @@ -0,0 +1,82 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class AdamOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto param_out_tensor = ctx.Output("ParamOut"); + auto moment1_out_tensor = ctx.Output("Moment1Out"); + auto moment2_out_tensor = ctx.Output("Moment2Out"); + auto beta1_pow_out_tensor = ctx.Output("Beta1PowOut"); + auto beta2_pow_out_tensor = ctx.Output("Beta2PowOut"); + + param_out_tensor->mutable_data(ctx.GetPlace()); + moment1_out_tensor->mutable_data(ctx.GetPlace()); + moment2_out_tensor->mutable_data(ctx.GetPlace()); + beta1_pow_out_tensor->mutable_data(ctx.GetPlace()); + beta2_pow_out_tensor->mutable_data(ctx.GetPlace()); + + float beta1 = ctx.Attr("beta1"); + float beta2 = ctx.Attr("beta2"); + float epsilon = ctx.Attr("epsilon"); + + auto param = framework::EigenVector::Flatten( + *ctx.Input("Param")); + auto grad = framework::EigenVector::Flatten( + *ctx.Input("Grad")); + auto moment1 = framework::EigenVector::Flatten( + *ctx.Input("Moment1")); + auto moment2 = framework::EigenVector::Flatten( + *ctx.Input("Moment2")); + auto lr = framework::EigenVector::Flatten( + *ctx.Input("LearningRate")); + auto beta1_pow = framework::EigenVector::Flatten( + *ctx.Input("Beta1Pow")); + auto beta2_pow = framework::EigenVector::Flatten( + *ctx.Input("Beta2Pow")); + auto param_out = framework::EigenVector::Flatten(*param_out_tensor); + auto moment1_out = framework::EigenVector::Flatten(*moment1_out_tensor); + auto moment2_out = framework::EigenVector::Flatten(*moment2_out_tensor); + auto beta1_pow_out = + framework::EigenVector::Flatten(*beta1_pow_out_tensor); + auto beta2_pow_out = + framework::EigenVector::Flatten(*beta2_pow_out_tensor); + auto place = ctx.GetEigenDevice(); + + moment1_out.device(place) = beta1 * moment1 + (1 - beta1) * grad; + moment2_out.device(place) = beta2 * moment2 + (1 - beta2) * grad.square(); + beta1_pow_out.device(place) = beta1_pow * beta1; + beta2_pow_out.device(place) = beta2_pow * beta2; + // All of these are tensors of 1 element + auto lr_t = lr * (1 - beta2_pow_out).sqrt() / (1 - beta1_pow_out); + // Eigen does not support automatic broadcast + // Get dimensions of moment vector to broadcast lr_t + Eigen::DSizes m_dsize(moment1_out_tensor->numel()); + param_out.device(place) = + param - + lr_t.broadcast(m_dsize) * + (moment1_out / (moment2_out.sqrt() + epsilon)); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/conv2d_op.cc b/paddle/operators/conv2d_op.cc index 6325d4248f10ea8a12ae5398d9fe0e579db3f7ae..1acb8415d0691df77047806d3c81b51cbb8c59f3 100644 --- a/paddle/operators/conv2d_op.cc +++ b/paddle/operators/conv2d_op.cc @@ -12,111 +12,91 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "paddle/operators/gemm_conv2d_op.h" +#include "paddle/operators/conv2d_op.h" namespace paddle { namespace operators { -int outputSize(int input_size, int filter_size, int padding, int stride) { - int output_size = (input_size - filter_size + 2 * padding) / stride + 1; - return output_size; +void Conv2DOp::InferShape(framework::InferShapeContext* ctx) const { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(Input) of Conv2DOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Filter"), + "Input(Filter) of Conv2DOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Output"), + "Output(Output) of Conv2DOp should not be null."); + + auto in_dims = ctx->GetInputDim("Input"); + auto filter_dims = ctx->GetInputDim("Filter"); + std::vector strides = ctx->Attrs().Get>("strides"); + std::vector paddings = ctx->Attrs().Get>("paddings"); + int groups = ctx->Attrs().Get("groups"); + int input_channels = in_dims[1]; + int output_channels = filter_dims[0]; + + PADDLE_ENFORCE_EQ(in_dims.size(), 4, "Conv2DOp input should be 4-D."); + PADDLE_ENFORCE_EQ(filter_dims.size(), 4, "Conv2DOp filter should be 4-D."); + PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups, + "The number of input channels should be equal to filter " + "channels * groups."); + PADDLE_ENFORCE_EQ( + output_channels % groups, 0, + "The number of output channels should be divided by groups."); + + auto output_height = + OutputSize(in_dims[2], filter_dims[2], paddings[0], strides[0]); + auto output_width = + OutputSize(in_dims[3], filter_dims[3], paddings[1], strides[1]); + ctx->SetOutputDim("Output", + {in_dims[0], filter_dims[0], output_height, output_width}); } -class Conv2DOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Input"), - "Input(Input) of Conv2DOp should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Filter"), - "Input(Filter) of Conv2DOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Output"), - "Output(Output) of Conv2DOp should not be null."); - - auto in_dims = ctx->GetInputDim("Input"); - auto filter_dims = ctx->GetInputDim("Filter"); - std::vector strides = ctx->Attrs().Get>("strides"); - std::vector paddings = ctx->Attrs().Get>("paddings"); - int groups = ctx->Attrs().Get("groups"); - int input_channels = in_dims[1]; - int output_channels = filter_dims[0]; - - PADDLE_ENFORCE_EQ(in_dims.size(), 4, "Conv2DOp input should be 4-D."); - PADDLE_ENFORCE_EQ(filter_dims.size(), 4, "Conv2DOp filter should be 4-D."); - PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups, - "The number of input channels should be equal to filter " - "channels * groups."); - PADDLE_ENFORCE_EQ( - output_channels % groups, 0, - "The number of output channels should be divided by groups."); - - auto output_height = - outputSize(in_dims[2], filter_dims[2], paddings[0], strides[0]); - auto output_width = - outputSize(in_dims[3], filter_dims[3], paddings[1], strides[1]); - ctx->SetOutputDim( - "Output", {in_dims[0], filter_dims[0], output_height, output_width}); - } -}; - -class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker { - public: - Conv2DOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput( - "Input", - "The input tensor of convolution operator. " - "The format of input tensor is NCHW. 
Where N is batch size, C is the " - "number of channels, H and W is the height and width of image."); - AddInput( - "Filter", - "The filter tensor of convolution operator." - "The format of the filter tensor is MCHW, where M is the number of " - "output image channels, C is the number of input image channels, " - "H and W is height and width of filter. " - "If the groups attribute is greater than 1, C equal the number of " - "input image channels divided by the groups."); - AddOutput("Output", - "The output tensor of convolution operator." - "The format of output tensor is also NCHW."); - AddAttr>("strides", "strides of convolution operator.") - .SetDefault({1, 1}); - AddAttr>("paddings", "paddings of convolution operator.") - .SetDefault({0, 0}); - AddAttr( - "groups", - "group size of convolution operator. " - "Refer to grouped convolution in Alex Krizhevsky's paper: " - "when group=2, the first half of the filters are only connected to the " - "first half of the input channels, and the second half only connected " - "to the second half.") - .SetDefault(1); - AddComment(R"DOC( +Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "Input", + "The input tensor of convolution operator. " + "The format of input tensor is NCHW. Where N is batch size, C is the " + "number of channels, H and W is the height and width of image."); + AddInput("Filter", + "The filter tensor of convolution operator." + "The format of the filter tensor is MCHW, where M is the number of " + "output image channels, C is the number of input image channels, " + "H and W is height and width of filter. " + "If the groups attribute is greater than 1, C equal the number of " + "input image channels divided by the groups."); + AddOutput("Output", + "The output tensor of convolution operator." + "The format of output tensor is also NCHW."); + AddAttr>("strides", "strides of convolution operator.") + .SetDefault({1, 1}); + AddAttr>("paddings", "paddings of convolution operator.") + .SetDefault({0, 0}); + AddAttr( + "groups", + "group size of convolution operator. " + "Refer to grouped convolution in Alex Krizhevsky's paper: " + "when group=2, the first half of the filters are only connected to the " + "first half of the input channels, and the second half only connected " + "to the second half.") + .SetDefault(1); + AddComment(R"DOC( The convolution operation calculates the output based on the input, filter and strides, paddings, groups parameters. The size of each dimension of the parameters is checked in the infer-shape. 
)DOC"); - } -}; - -class Conv2DOpGrad : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; +} - protected: - void InferShape(framework::InferShapeContext* ctx) const override { - auto in_dims = ctx->GetInputDim("Input"); - auto filter_dims = ctx->GetInputDim("Filter"); - if (ctx->HasOutput(framework::GradVarName("Input"))) { - ctx->SetOutputDim(framework::GradVarName("Input"), in_dims); - } - if (ctx->HasOutput(framework::GradVarName("Filter"))) { - ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims); - } +void Conv2DOpGrad::InferShape(framework::InferShapeContext* ctx) const { + auto in_dims = ctx->GetInputDim("Input"); + auto filter_dims = ctx->GetInputDim("Filter"); + if (ctx->HasOutput(framework::GradVarName("Input"))) { + ctx->SetOutputDim(framework::GradVarName("Input"), in_dims); } -}; + if (ctx->HasOutput(framework::GradVarName("Filter"))) { + ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims); + } +} } // namespace operators } // namespace paddle diff --git a/paddle/operators/conv2d_op.cu b/paddle/operators/conv2d_op.cu index 5df818ba0496a65502dde37fd1397ec56f8c1101..c697c9466d34c29af6976f3a4d2d0a24ba778ceb 100644 --- a/paddle/operators/conv2d_op.cu +++ b/paddle/operators/conv2d_op.cu @@ -12,7 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/gemm_conv2d_op.h" +#include "paddle/operators/conv2d_op.h" namespace ops = paddle::operators; diff --git a/paddle/operators/gemm_conv2d_op.h b/paddle/operators/conv2d_op.h similarity index 90% rename from paddle/operators/gemm_conv2d_op.h rename to paddle/operators/conv2d_op.h index 323e3f7c3bd506c6b63bf4d1152384649f5da575..7ebdbe81cbbaf59a60eb3dac0f570d70fc85d6ef 100644 --- a/paddle/operators/gemm_conv2d_op.h +++ b/paddle/operators/conv2d_op.h @@ -24,6 +24,38 @@ namespace operators { using Tensor = framework::Tensor; +// Base convolution operator definations for other conv +// like operators to reuse the implementation. +inline int OutputSize(int input_size, int filter_size, int padding, + int stride) { + int output_size = (input_size - filter_size + 2 * padding) / stride + 1; + return output_size; +} + +// Define Op classes in .h file so that other conv +// operator implementations can reuse the code. 
+class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker { + public: + Conv2DOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker); +}; + +class Conv2DOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override; +}; + +class Conv2DOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override; +}; + template class GemmConv2DKernel : public framework::OpKernel { public: @@ -74,7 +106,6 @@ class GemmConv2DKernel : public framework::OpKernel { framework::DDim output_matrix_shape = {output_channels, output_height * output_width}; - // convolution operator: im2col + gemm int in_step = input_channels / groups; int out_step = output_channels / groups; diff --git a/paddle/operators/conv_cudnn_op.cc b/paddle/operators/conv_cudnn_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..4288f300dd5b0464f2b4394cdb0b44f93060ae74 --- /dev/null +++ b/paddle/operators/conv_cudnn_op.cc @@ -0,0 +1,47 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/conv2d_op.h" + +namespace paddle { +namespace operators { + +class CudnnConvOpMaker : public Conv2DOpMaker { + public: + CudnnConvOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : Conv2DOpMaker(proto, op_checker) { + AddAttr>("dilations", "dilations of convolution operator.") + .SetDefault(std::vector{1, 1}); + AddAttr("workspace_size_MB", + "workspace size for cudnn, in MB, " + "workspace is a section of GPU memory which will be " + "allocated/freed each time the operator runs, larger " + "workspace size can increase performance but also requires " + "better hardward. This size should be carefully setted.") + .SetDefault(4096); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(conv_cudnn, ops::Conv2DOp, ops::CudnnConvOpMaker, conv_cudnn_grad, + ops::Conv2DOpGrad); +REGISTER_OP_CPU_KERNEL( + conv_cudnn, ops::GemmConv2DKernel); +REGISTER_OP_CPU_KERNEL( + conv_cudnn_grad, + ops::GemmConvGrad2DKernel); diff --git a/paddle/operators/conv_cudnn_op.cu b/paddle/operators/conv_cudnn_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..366d0323b840c338dd6ba5b28bdb29fd135fe91a --- /dev/null +++ b/paddle/operators/conv_cudnn_op.cu @@ -0,0 +1,277 @@ +/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/memory/memory.h" +#include "paddle/operators/conv2d_op.h" +#include "paddle/platform/assert.h" +#include "paddle/platform/cudnn_helper.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; +using ScopedFilterDescriptor = platform::ScopedFilterDescriptor; +using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor; +using DataLayout = platform::DataLayout; +using CUDADeviceContext = platform::CUDADeviceContext; + +static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = 1024 * 1024 * 1024; + +// NOTE: framework::vectorize converts to type int64_t +// which does not fit cudnn inputs. +std::vector Dims2Vector(const framework::DDim& dims) { + std::vector ret; + for (int i = 0; i < dims.size(); i++) { + ret.push_back(dims[i]); + } + return ret; +} + +template +class CudnnConvOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + auto* input = ctx.Input("Input"); + auto* filter = ctx.Input("Filter"); + auto* output = ctx.Output("Output"); + + std::vector strides = ctx.Attr>("strides"); + std::vector paddings = ctx.Attr>("paddings"); + std::vector dilations = ctx.Attr>("dilations"); + int groups = ctx.Attr("groups"); + int user_workspace_size = ctx.Attr("workspace_size_MB"); + + const T* input_data = input->data(); + const T* filter_data = filter->data(); + T* output_data = output->mutable_data(ctx.GetPlace()); + + // ------------------- cudnn descriptors --------------------- + ScopedTensorDescriptor input_desc; + ScopedTensorDescriptor output_desc; + ScopedFilterDescriptor filter_desc; + ScopedConvolutionDescriptor conv_desc; + DataLayout layout = DataLayout::kNCHW; + + cudnnTensorDescriptor_t cudnn_input_desc = + input_desc.descriptor(layout, Dims2Vector(input->dims()), groups); + cudnnTensorDescriptor_t cudnn_output_desc = + output_desc.descriptor(layout, Dims2Vector(output->dims()), groups); + cudnnFilterDescriptor_t cudnn_filter_desc = + filter_desc.descriptor(layout, Dims2Vector(filter->dims()), groups); + cudnnConvolutionDescriptor_t cudnn_conv_desc = + conv_desc.descriptor(paddings, strides, dilations); + + int input_channels = input->dims()[1]; + int input_height = input->dims()[2]; + int input_width = input->dims()[3]; + int output_channels = output->dims()[1]; + int output_height = output->dims()[2]; + int output_width = output->dims()[3]; + + int group_offset_in = input_channels / groups * input_height * input_width; + int group_offset_out = + output_channels / groups * output_height * output_width; + int group_offset_filter = filter->numel() / groups; + // ------------------- cudnn conv workspace --------------------- + void* cudnn_workspace = nullptr; + size_t workspace_size_in_bytes; // final workspace to allocate. 
+ size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES; + if (user_workspace_size > 0) { + workspace_size_limit = user_workspace_size * 1024 * 1024; + } + // ------------------- cudnn conv algorithm --------------------- + cudnnConvolutionFwdAlgo_t algo; + auto handle = ctx.cuda_device_context().cudnn_handle(); + + PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm( + handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, + cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, + workspace_size_limit, &algo)); + // get workspace size able to allocate + PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize( + handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, + cudnn_output_desc, algo, &workspace_size_in_bytes)); + // Allocate on GPU memory + platform::GPUPlace gpu = boost::get(ctx.GetPlace()); + cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes); + // ------------------- cudnn conv forward --------------------- + T alpha = 1.0f, beta = 0.0f; + for (int i = 0; i < groups; i++) { + PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward( + handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in, + cudnn_filter_desc, filter_data + i * group_offset_filter, + cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes, + &beta, cudnn_output_desc, output_data + i * group_offset_out)); + } + // Release the cudnn workspace + paddle::memory::Free(gpu, cudnn_workspace); + } +}; + +template +class CudnnConvGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + auto input = ctx.Input("Input"); + auto filter = ctx.Input("Filter"); + auto output_grad = ctx.Input(framework::GradVarName("Output")); + auto input_grad = ctx.Output(framework::GradVarName("Input")); + auto filter_grad = ctx.Output(framework::GradVarName("Filter")); + + const T* input_data = input->data(); + const T* output_grad_data = output_grad->data(); + const T* filter_data = filter->data(); + + std::vector strides = ctx.Attr>("strides"); + std::vector paddings = ctx.Attr>("paddings"); + std::vector dilations = ctx.Attr>("dilations"); + int groups = ctx.Attr("groups"); + int user_workspace_size = ctx.Attr("workspace_size_MB"); + + // ------------------- cudnn descriptors --------------------- + ScopedTensorDescriptor input_desc; + ScopedTensorDescriptor output_grad_desc; + ScopedTensorDescriptor input_grad_desc; + + ScopedFilterDescriptor filter_desc; + ScopedFilterDescriptor filter_grad_desc; + ScopedConvolutionDescriptor conv_desc; + DataLayout layout = DataLayout::kNCHW; + + cudnnTensorDescriptor_t cudnn_input_desc = + input_desc.descriptor(layout, Dims2Vector(input->dims()), groups); + cudnnTensorDescriptor_t cudnn_output_grad_desc = + output_grad_desc.descriptor(layout, Dims2Vector(output_grad->dims()), + groups); + cudnnFilterDescriptor_t cudnn_filter_desc = + filter_desc.descriptor(layout, Dims2Vector(filter->dims()), groups); + cudnnTensorDescriptor_t cudnn_input_grad_desc = nullptr; + cudnnFilterDescriptor_t cudnn_filter_grad_desc = nullptr; + + cudnnConvolutionDescriptor_t cudnn_conv_desc = + conv_desc.descriptor(paddings, strides, dilations); + + int input_channels = input->dims()[1]; + int input_height = input->dims()[2]; + int input_width = input->dims()[3]; + int output_grad_channels = filter->dims()[0]; + int output_grad_height = output_grad->dims()[2]; + 
int output_grad_width = output_grad->dims()[3]; + + int group_offset_in = input_channels / groups * input_height * input_width; + int group_offset_out = + output_grad_channels / groups * output_grad_height * output_grad_width; + int group_offset_filter = filter->numel() / groups; + // ------------------- cudnn backward algorithm --------------------- + cudnnConvolutionBwdDataAlgo_t data_algo; + cudnnConvolutionBwdFilterAlgo_t filter_algo; + size_t workspace_size_in_bytes = 0, tmp_size = 0; + size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES; + if (user_workspace_size > 0) { + workspace_size_limit = user_workspace_size * 1024 * 1024; + } + + auto handle = ctx.cuda_device_context().cudnn_handle(); + if (input_grad) { + cudnn_input_grad_desc = input_grad_desc.descriptor( + layout, Dims2Vector(input_grad->dims()), groups); + PADDLE_ENFORCE( + platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm( + handle, cudnn_filter_desc, + // dyDesc: Handle to the previously initialized input differential + // tensor descriptor. + cudnn_output_grad_desc, cudnn_conv_desc, + // dxDesc: Handle to the previously initialized output tensor + // descriptor. + cudnn_input_grad_desc, + CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, + workspace_size_limit, &data_algo)); + PADDLE_ENFORCE( + platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize( + handle, cudnn_filter_desc, cudnn_output_grad_desc, + cudnn_conv_desc, cudnn_input_grad_desc, data_algo, &tmp_size)); + workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size); + } + + if (filter_grad) { + cudnn_filter_grad_desc = filter_grad_desc.descriptor( + layout, Dims2Vector(filter_grad->dims()), groups); + PADDLE_ENFORCE( + platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm( + handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc, + cudnn_filter_desc, + CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, + workspace_size_limit, &filter_algo)); + + PADDLE_ENFORCE( + platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize( + handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc, + cudnn_filter_desc, filter_algo, &tmp_size)); + workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size); + } + // ------------------- cudnn conv workspace --------------------- + // Already on GPU + void* cudnn_workspace = nullptr; + platform::GPUPlace gpu = boost::get(ctx.GetPlace()); + cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes); + // ------------------- cudnn conv backward data --------------------- + // FIXME(typhoonzero): template type T may not be the same as cudnn call. 
+ T alpha = 1.0f, beta = 0.0f; + if (input_grad) { + T* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); + auto t = framework::EigenVector::Flatten(*input_grad); + t.device(ctx.GetEigenDevice()) = + t.constant(static_cast(0)); + for (int i = 0; i < groups; i++) { + PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( + handle, &alpha, cudnn_filter_desc, + filter_data + i * group_offset_filter, cudnn_output_grad_desc, + output_grad_data + i * group_offset_out, cudnn_conv_desc, data_algo, + cudnn_workspace, workspace_size_in_bytes, &beta, + cudnn_input_grad_desc, input_grad_data + i * group_offset_in)); + } + } + // ------------------- cudnn conv backward filter --------------------- + if (filter_grad) { + T* filter_grad_data = filter_grad->mutable_data(ctx.GetPlace()); + auto t = framework::EigenVector::Flatten(*filter_grad); + t.device(ctx.GetEigenDevice()) = + t.constant(static_cast(0)); + for (int i = 0; i < groups; i++) { + PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( + handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in, + cudnn_output_grad_desc, output_grad_data + i * group_offset_out, + cudnn_conv_desc, filter_algo, cudnn_workspace, + workspace_size_in_bytes, &beta, cudnn_filter_grad_desc, + filter_grad_data + i * group_offset_filter)); + } + } + // Release the cudnn workspace + paddle::memory::Free(gpu, cudnn_workspace); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP_GPU_KERNEL(conv_cudnn, paddle::operators::CudnnConvOpKernel); +REGISTER_OP_GPU_KERNEL(conv_cudnn_grad, + paddle::operators::CudnnConvGradOpKernel); diff --git a/paddle/operators/decayed_adagrad_op.cc b/paddle/operators/decayed_adagrad_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..7f583f18c8c6ee5025f6525306f9323fb329b030 --- /dev/null +++ b/paddle/operators/decayed_adagrad_op.cc @@ -0,0 +1,96 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/operators/decayed_adagrad_op.h" + +namespace paddle { +namespace operators { + +class DecayedAdagradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Param"), + "Input(Param) of DecayedAdagradOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Grad"), + "Input(Grad) of DecayedAdagradOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Moment"), + "Input(Moment) of DecayedAdagradOp should not be null."); + PADDLE_ENFORCE( + ctx->HasInput("LearningRate"), + "Input(LearningRate) of DecayedAdagradOp should not be null."); + + PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), + "Output(ParamOut) of DecayedAdagradOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("MomentOut"), + "Output(MomentOut) of DecayedAdagradOp should not be null."); + + auto lr_dims = ctx->GetInputDim("LearningRate"); + PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1, + "LearningRate should have one element"); + auto param_dims = ctx->GetInputDim("Param"); + PADDLE_ENFORCE_EQ(param_dims, ctx->GetInputDim("Grad"), + "Param and Grad input of DecayedAdagradOp should have " + "the same dimension."); + PADDLE_ENFORCE_EQ(param_dims, ctx->GetInputDim("Moment"), + "Param and Moment input of DecayedAdagradOp should have " + "the same dimension."); + + ctx->SetOutputDim("ParamOut", param_dims); + ctx->SetOutputDim("MomentOut", param_dims); + } +}; + +class DecayedAdagradOpMaker : public framework::OpProtoAndCheckerMaker { + public: + DecayedAdagradOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Param", "(Tensor) Input parameter"); + AddInput("Grad", "(Tensor) Input gradient"); + AddInput("Moment", "(Tensor) Second moment"); + AddInput("LearningRate", "(Tensor) Learning rate"); + + AddOutput("ParamOut", "(Tensor) Output parameter"); + AddOutput("MomentOut", "(Tensor) Output second moment"); + + AddAttr("decay", + "(float, default 0.95) " + "Discounting factor for coming gradient") + .SetDefault(0.95); + AddAttr("epsilon", + "(float, default 1.0e-6) " + "Constant for numerical stability") + .SetDefault(1.0e-6f); + AddComment(R"DOC( + +Decayed Adagrad + +moment_out = decay * moment + (1 - decay) * grad * grad +param_out = param - learning_rate * grad / (sqrt(moment_out) + epsilon) + +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(decayed_adagrad, ops::DecayedAdagradOp, + ops::DecayedAdagradOpMaker); +REGISTER_OP_CPU_KERNEL( + decayed_adagrad, + ops::DecayedAdagradOpKernel); diff --git a/paddle/operators/decayed_adagrad_op.cu b/paddle/operators/decayed_adagrad_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..6fce77fe4ec6b76cb7b0259aab6a3d55d2edb36c --- /dev/null +++ b/paddle/operators/decayed_adagrad_op.cu @@ -0,0 +1,21 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/decayed_adagrad_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL( + decayed_adagrad, + ops::DecayedAdagradOpKernel); diff --git a/paddle/operators/decayed_adagrad_op.h b/paddle/operators/decayed_adagrad_op.h new file mode 100644 index 0000000000000000000000000000000000000000..0fe0fc5acd66c9824a864618b69097c5c063ea3f --- /dev/null +++ b/paddle/operators/decayed_adagrad_op.h @@ -0,0 +1,56 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class DecayedAdagradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto param_out_tensor = ctx.Output("ParamOut"); + auto moment_out_tensor = ctx.Output("MomentOut"); + + param_out_tensor->mutable_data(ctx.GetPlace()); + moment_out_tensor->mutable_data(ctx.GetPlace()); + + float decay = ctx.Attr("decay"); + float epsilon = ctx.Attr("epsilon"); + + auto param = framework::EigenVector::Flatten( + *ctx.Input("Param")); + auto grad = framework::EigenVector::Flatten( + *ctx.Input("Grad")); + auto moment = framework::EigenVector::Flatten( + *ctx.Input("Moment")); + auto lr = framework::EigenVector::Flatten( + *ctx.Input("LearningRate")); + + auto param_out = framework::EigenVector::Flatten(*param_out_tensor); + auto moment_out = framework::EigenVector::Flatten(*moment_out_tensor); + auto place = ctx.GetEigenDevice(); + + moment_out.device(place) = decay * moment + (1 - decay) * grad * grad; + Eigen::DSizes m_dsize(moment_out_tensor->numel()); + param_out.device(place) = + param - lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/vol2col_test.cc b/paddle/operators/math/vol2col_test.cc index 81225e9a9803ce371d23620876ac22da63a8e2d1..2d69218843a69497b5b501d4297f2ec5ab26a844 100644 --- a/paddle/operators/math/vol2col_test.cc +++ b/paddle/operators/math/vol2col_test.cc @@ -78,7 +78,7 @@ void testVol2col() { if (paddle::platform::is_cpu_place(*place)) { input = input_tmp; } else { - input.CopyFrom(input_tmp, *place); + input.CopyFrom(input_tmp, *place, *context); } output.mutable_data({1, filter_size, filter_size, filter_size, output_depth, output_height, output_width}, @@ -93,7 +93,7 @@ void testVol2col() { if (paddle::platform::is_cpu_place(*place)) { out_cfo_ptr = output.data(); } else { - output_tmp.CopyFrom(output, paddle::platform::CPUPlace()); + output_tmp.CopyFrom(output, paddle::platform::CPUPlace(), *context); out_cfo_ptr = output_tmp.data(); } @@ -107,7 +107,7 @@ void testVol2col() { if (paddle::platform::is_cpu_place(*place)) { input = input_tmp; } else { - input.CopyFrom(input_tmp, *place); + 
input.CopyFrom(input_tmp, *place, *context); } paddle::operators::math::Col2VolFunctor col2vol; @@ -118,7 +118,7 @@ void testVol2col() { if (paddle::platform::is_cpu_place(*place)) { in_ptr = input.data(); } else { - input_tmp.CopyFrom(input, paddle::platform::CPUPlace()); + input_tmp.CopyFrom(input, paddle::platform::CPUPlace(), *context); in_ptr = input_tmp.data(); } diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc index e330877fc4283b796dcb5c5d745881884ae491ae..75928f1ec818ab028ea06cfa72273fb99430c3c8 100644 --- a/paddle/operators/uniform_random_op.cc +++ b/paddle/operators/uniform_random_op.cc @@ -54,7 +54,7 @@ class UniformRandomOp : public framework::OperatorWithKernel { PADDLE_ENFORCE( ctx->Attrs().Get("min") < ctx->Attrs().Get("max"), "uniform_random's min must less then max"); - auto dims = Attr>("dims"); + auto& dims = ctx->Attrs().Get>("dims"); std::vector temp; temp.reserve(dims.size()); for (auto dim : dims) { diff --git a/paddle/platform/cudnn_helper.h b/paddle/platform/cudnn_helper.h index 2841d2a2dbec5c17ef098a06c976ca01247820f5..0c5719ef5162546578253e383209b1893c0cd71f 100644 --- a/paddle/platform/cudnn_helper.h +++ b/paddle/platform/cudnn_helper.h @@ -71,23 +71,32 @@ class ScopedTensorDescriptor { inline cudnnTensorDescriptor_t descriptor(const cudnnTensorFormat_t format, const cudnnDataType_t type, - const std::vector& dims) { - // the format is not used now, but it maybe useful feature + const std::vector& dims, + const int groups = 1) { + // the format is not used now, will add later std::vector strides(dims.size()); strides[dims.size() - 1] = 1; for (int i = dims.size() - 2; i >= 0; i--) { strides[i] = dims[i + 1] * strides[i + 1]; } + // Update tensor descriptor dims setting if groups > 1 + // FIXME(typhoonzero): Assume using NCHW order + std::vector dims_with_group(dims.begin(), dims.end()); // copy + if (groups > 1) { + dims_with_group[1] = dims_with_group[1] / groups; + } PADDLE_ENFORCE(dynload::cudnnSetTensorNdDescriptor( - desc_, type, dims.size(), dims.data(), strides.data())); + desc_, type, dims_with_group.size(), dims_with_group.data(), + strides.data())); return desc_; } template inline cudnnTensorDescriptor_t descriptor(const DataLayout& order, - const std::vector& dims) { - return descriptor(GetCudnnTensorFormat(order), CudnnDataType::type, - dims); + const std::vector& dims, + const int groups = 1) { + return descriptor(GetCudnnTensorFormat(order), CudnnDataType::type, dims, + groups); } private: @@ -106,18 +115,29 @@ class ScopedFilterDescriptor { inline cudnnFilterDescriptor_t descriptor(const cudnnTensorFormat_t format, const cudnnDataType_t type, - const std::vector& kernel) { - // filter layout: output input spatial_dim_y spatial_dim_x + const std::vector& kernel, + const int groups = 1) { + // filter layout: MCHW, where M is the number of + // output image channels, C is the number of input image channels, + // H and W is height and width of filter. + std::vector kernel_with_group(kernel.begin(), kernel.end()); + if (groups > 1) { + // M /= groups + kernel_with_group[0] /= groups; + // NOTE: input filter(C) of the filter is already asserted to be C/groups. 
+ } PADDLE_ENFORCE(dynload::cudnnSetFilterNdDescriptor( - desc_, type, format, kernel.size(), kernel.data())); + desc_, type, format, kernel_with_group.size(), + kernel_with_group.data())); return desc_; } template inline cudnnFilterDescriptor_t descriptor(const DataLayout& order, - const std::vector& kernel) { + const std::vector& kernel, + const int groups = 1) { return descriptor(GetCudnnTensorFormat(order), CudnnDataType::type, - kernel); + kernel, groups); } private: diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt index 97364f2db9523c0629616692631d8372657a2128..b8fc9347243ac490efcb09132f4b049c6e9f8e08 100644 --- a/paddle/pybind/CMakeLists.txt +++ b/paddle/pybind/CMakeLists.txt @@ -1,6 +1,6 @@ if(WITH_PYTHON) cc_library(paddle_pybind SHARED SRCS pybind.cc exception.cc protobuf.cc - DEPS pybind python backward proto_desc tensor_array + DEPS pybind python backward proto_desc tensor_array paddle_memory ${GLOB_OP_LIB}) endif(WITH_PYTHON) diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 0e4bbe8415fd86ab29c6809e7652dc581b4e6004..7ab4e6a451846199d249ee8c6cf24483802a58da 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -204,7 +204,7 @@ void BindOpDesc(py::module &m) { .def("set_attr", &OpDescBind::SetAttr) .def("attr", &OpDescBind::GetAttr) .def("set_block_attr", &OpDescBind::SetBlockAttr) - .def("get_block_attr", &OpDescBind::GetBlockAttr) + .def("block_attr", &OpDescBind::GetBlockAttr) .def("check_attrs", &OpDescBind::CheckAttrs) .def("infer_shape", &OpDescBind::InferShape); } diff --git a/python/paddle/v2/framework/graph.py b/python/paddle/v2/framework/framework.py similarity index 66% rename from python/paddle/v2/framework/graph.py rename to python/paddle/v2/framework/framework.py index 0f0a2847e58a1ca172bf1ba382abb2ebc1ecb8ed..2afbd0c83158d583dd637cceeb35321ddb68f323 100644 --- a/python/paddle/v2/framework/graph.py +++ b/python/paddle/v2/framework/framework.py @@ -1,4 +1,5 @@ import paddle.v2.framework.core as core +import paddle.v2.framework.proto.framework_pb2 as framework_pb2 import collections import numpy as np import copy @@ -106,6 +107,40 @@ class Variable(object): raise ValueError("Not supported numpy dtype " + str(dtype)) +def get_all_op_protos(): + """ + Get all registered op proto from PaddlePaddle C++ end. + :return: A list of registered OpProto. + """ + protostrs = core.get_all_op_protos() + ret_values = [] + for pbstr in protostrs: + op_proto = framework_pb2.OpProto.FromString(str(pbstr)) + ret_values.append(op_proto) + return ret_values + + +class OpProtoHolder(object): + @classmethod + def instance(cls): + if not hasattr(cls, '_instance'): + cls._instance = cls() + return cls._instance + + def __init__(self): + assert not hasattr( + self.__class__, + '_instance'), 'Please use `instance()` to get OpProtoHolder opject!' + op_protos = get_all_op_protos() + self.op_proto_map = {} + for proto in op_protos: + self.op_proto_map[proto.type] = proto + + def get_op_proto(self, type): + assert type in self.op_proto_map, "Operator \"%s\" has not been registered." % type + return self.op_proto_map[type] + + class Operator(object): def __init__(self, block, @@ -116,20 +151,89 @@ class Operator(object): attrs=None): self.block = block self.desc = desc - if type is not None: - # TODO. 
- pass + if len(self.desc.type()) != 0: + return + if type is None: + raise ValueError( + "`type` to initilized an Operator can not be None.") + self.desc.set_type(type) + proto = OpProtoHolder.instance().get_op_proto(type) + if inputs is not None: - # TODO - pass + for in_proto in proto.inputs: + in_argus = inputs[in_proto.name] + if not isinstance(in_argus, list): + in_argus = [in_argus] + if not in_proto.duplicable and len(in_argus) > 1: + raise ValueError( + "Input %s expects only one input, but %d are given." % + (in_proto.name, len(in_argus))) + in_argu_names = [] + for argu in in_argus: + in_argu_names.append(argu.name) + self.desc.set_input(in_proto.name, in_argu_names) + if outputs is not None: - # TODO - pass + for out_proto in proto.outputs: + out_argus = outputs[out_proto.name] + if not isinstance(out_argus, list): + out_argus = [out_argus] + if not out_proto.duplicable and len(out_argus) > 1: + raise ValueError( + "Output %s expects only one output, but %d are given." % + (out_proto.name, len(out_argus))) + out_argu_names = [] + for argu in out_argus: + out_argu_names.append(argu.name) + argu.op = self + self.desc.set_output(out_proto.name, out_argu_names) + if attrs is not None: - # TODO - pass + for attr in proto.attrs: + attr_name = attr.name + if not attr_name in attrs: + continue + if not isinstance(attrs[attr_name], Block): + self.desc.set_attr(attr_name, attrs[attr_name]) + else: + self.desc.set_block_attr(attr_name, attrs[attr_name].desc) + + self.desc.check_attrs() + self.desc.infer_shape(self.block.desc) + + @property + def type(self): + return self.desc.type() + + def input(self, name): + return self.desc.input(name) + + @property + def input_names(self): + return self.desc.input_names() + + def output(self, name): + return self.desc.output(name) + + @property + def output_names(self): + return self.desc.output_names() + + def has_attr(self, name): + return self.desc.has_attr(name) + + def attr_type(self, name): + return self.desc.attr_type(name) + + @property + def attr_names(self): + return self.desc.attr_names() + + def attr(self, name): + return self.desc.attr(name) - # TODO: Getters + def block_attr(self, name): + return self.desc.block_attr(name) class Block(object): diff --git a/python/paddle/v2/framework/tests/test_adam_op.py b/python/paddle/v2/framework/tests/test_adam_op.py new file mode 100644 index 0000000000000000000000000000000000000000..ff6faafa6e2119fde11b9eb6cd2a65a75334ebe6 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_adam_op.py @@ -0,0 +1,186 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestAdamOp1(OpTest): + def setUp(self): + '''Test Adam Op with supplied attributes + ''' + self.op_type = "adam" + param = np.random.uniform(-1, 1, (102, 105)).astype("float32") + grad = np.random.uniform(-1, 1, (102, 105)).astype("float32") + moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32") + # The second moment is positive + moment2 = np.random.random((102, 105)).astype("float32") + + learning_rate = 0.004 + beta1 = 0.78 + beta2 = 0.836 + epsilon = 1e-4 + beta1_pow = beta1**10 + beta2_pow = beta2**10 + + self.inputs = { + 'Param': param, + 'Grad': grad, + 'Moment1': moment1, + 'Moment2': moment2, + 'LearningRate': np.array([learning_rate]).astype("float32"), + 'Beta1Pow': np.array([beta1_pow]).astype("float32"), + 'Beta2Pow': np.array([beta2_pow]).astype("float32") + } + + self.attrs = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2} + + param_out, moment1_out, moment2_out, beta1_pow_out, \ + 
beta2_pow_out = adam_step(self.inputs, self.attrs) + + self.outputs = { + 'Moment1Out': moment1_out, + 'Moment2Out': moment2_out, + 'Beta1PowOut': beta1_pow_out, + 'Beta2PowOut': beta2_pow_out, + 'ParamOut': param_out + } + + def test_check_output(self): + self.check_output() + + +class TestAdamOp2(OpTest): + def setUp(self): + '''Test Adam Op with supplied attributes + ''' + self.op_type = "adam" + param = np.random.uniform(-1, 1, (102, 105)).astype("float32") + grad = np.random.uniform(-1, 1, (102, 105)).astype("float32") + moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32") + # The second moment is positive + moment2 = np.random.random((102, 105)).astype("float32") + + learning_rate = 0.001 + beta1 = 0.9 + beta2 = 0.999 + epsilon = 1e-8 + beta1_pow = beta1**10 + beta2_pow = beta2**10 + + self.inputs = { + 'Param': param, + 'Grad': grad, + 'Moment1': moment1, + 'Moment2': moment2, + 'LearningRate': np.array([learning_rate]).astype("float32"), + 'Beta1Pow': np.array([beta1_pow]).astype("float32"), + 'Beta2Pow': np.array([beta2_pow]).astype("float32") + } + + attributes = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2} + + param_out, moment1_out, moment2_out, beta1_pow_out, \ + beta2_pow_out = adam_step(self.inputs, attributes) + + self.outputs = { + 'Moment1Out': moment1_out, + 'Moment2Out': moment2_out, + 'Beta1PowOut': beta1_pow_out, + 'Beta2PowOut': beta2_pow_out, + 'ParamOut': param_out + } + + def test_check_output(self): + self.check_output() + + +class TestAdamOpMultipleSteps(OpTest): + def setUp(self): + '''Test Adam Operator with supplied attributes + ''' + self.op_type = "adam" + self.num_steps = 10 + + param = np.random.uniform(-1, 1, (102, 105)).astype("float32") + grad = np.random.uniform(-1, 1, (102, 105)).astype("float32") + moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32") + # The second moment is positive + moment2 = np.random.random((102, 105)).astype("float32") + + learning_rate = 0.001 + beta1 = 0.9 + beta2 = 0.999 + epsilon = 1e-8 + beta1_pow = beta1**10 + beta2_pow = beta2**10 + + self.inputs = { + 'Param': param, + 'Grad': grad, + 'Moment1': moment1, + 'Moment2': moment2, + 'LearningRate': np.array([learning_rate]).astype("float32"), + 'Beta1Pow': np.array([beta1_pow]).astype("float32"), + 'Beta2Pow': np.array([beta2_pow]).astype("float32") + } + + self.attrs = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2} + + def test_check_output(self): + for _ in range(self.num_steps): + param_out, moment1_out, moment2_out, beta1_pow_out, \ + beta2_pow_out = adam_step(self.inputs, self.attrs) + + self.outputs = { + 'Moment1Out': moment1_out, + 'Moment2Out': moment2_out, + 'Beta1PowOut': beta1_pow_out, + 'Beta2PowOut': beta2_pow_out, + 'ParamOut': param_out + } + + # Verify output for this step + self.check_output() + + # Output of this step becomes input for next step + self.inputs['Param'] = param_out + self.inputs['Moment1'] = moment1_out + self.inputs['Moment2'] = moment2_out + self.inputs['Beta1Pow'] = beta1_pow_out + self.inputs['Beta2Pow'] = beta2_pow_out + + # Randomize gradient for next step + self.inputs['Grad'] = np.random.uniform( + -1, 1, (102, 105)).astype("float32") + + +def adam_step(inputs, attributes): + ''' + Simulate one step of the adam optimizer + :param inputs: dict of inputs + :param attributes: dict of attributes + :return tuple: tuple of output param, moment1, moment2, + beta1 power accumulator and beta2 power accumulator + ''' + param = inputs['Param'] + grad = inputs['Grad'] + moment1 = inputs['Moment1'] + moment2 
= inputs['Moment2'] + lr = inputs['LearningRate'] + beta1_pow = inputs['Beta1Pow'] + beta2_pow = inputs['Beta2Pow'] + + beta1 = attributes['beta1'] + beta2 = attributes['beta2'] + epsilon = attributes['epsilon'] + + moment1_out = beta1 * moment1 + (1 - beta1) * grad + moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad) + beta1_pow_out = beta1_pow * beta1 + beta2_pow_out = beta2_pow * beta2 + lr_t = lr * np.sqrt(1 - beta2_pow_out) / (1 - beta1_pow_out) + param_out = param - lr_t * (moment1_out / (np.sqrt(moment2_out) + epsilon)) + return param_out, moment1_out, moment2_out, beta1_pow_out, beta2_pow_out + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_conv2d_op.py b/python/paddle/v2/framework/tests/test_conv2d_op.py index 118a5fc1cde5f4a908b065d581956e0855d50a52..2fb808944ac97f2bdcb05336a2205346ded65a4d 100644 --- a/python/paddle/v2/framework/tests/test_conv2d_op.py +++ b/python/paddle/v2/framework/tests/test_conv2d_op.py @@ -3,70 +3,56 @@ import numpy as np from op_test import OpTest +def conv2d_forward_naive(input, filter, group, conv_param): + in_n, in_c, in_h, in_w = input.shape + out_c, f_c, f_h, f_w = filter.shape + assert f_c * group == in_c + assert np.mod(out_c, group) == 0 + sub_out_c = out_c / group + + stride, pad = conv_param['stride'], conv_param['pad'] + out_h = 1 + (in_h + 2 * pad[0] - f_h) / stride[0] + out_w = 1 + (in_w + 2 * pad[1] - f_w) / stride[1] + out = np.zeros((in_n, out_c, out_h, out_w)) + + input_pad = np.pad(input, ((0, ), (0, ), (pad[0], ), (pad[1], )), + mode='constant', + constant_values=0) + for i in range(out_h): + for j in range(out_w): + for g in range(group): + input_pad_masked = \ + input_pad[:, g * f_c:(g + 1) * f_c, + i * stride[0]:i * stride[0] + f_h, + j * stride[1]:j * stride[1] + f_w] + + f_sub = filter[g * sub_out_c:(g + 1) * sub_out_c, :, :, :] + for k in range(sub_out_c): + out[:, g * sub_out_c + k, i, j] = \ + np.sum(input_pad_masked * f_sub[k, :, :, :], + axis=(1, 2, 3)) + + return out + + class TestConv2dOp(OpTest): def setUp(self): - self.init_groups() - self.op_type = "conv2d" - batch_size = 2 - input_channels = 3 - input_height = 5 - input_width = 5 - output_channels = 6 - filter_height = 3 - filter_width = 3 - stride = 1 - padding = 0 - output_height = (input_height - filter_height + 2 * padding - ) / stride + 1 - output_width = (input_width - filter_width + 2 * padding) / stride + 1 - input = np.random.random((batch_size, input_channels, input_height, - input_width)).astype("float32") - - filter = np.random.random( - (output_channels, input_channels / self.groups, filter_height, - filter_width)).astype("float32") - output = np.ndarray( - (batch_size, output_channels, output_height, output_width)) + self.init_op_type() + self.init_group() + self.init_test_case() + + conv2d_param = {'stride': self.stride, 'pad': self.pad} + input = np.random.random(self.input_size).astype("float32") + filter = np.random.random(self.filter_size).astype("float32") + output = conv2d_forward_naive(input, filter, self.groups, conv2d_param) self.inputs = {'Input': input, 'Filter': filter} self.attrs = { - 'strides': [1, 1], - 'paddings': [0, 0], - 'groups': self.groups + 'strides': self.stride, + 'paddings': self.pad, + 'groups': self.groups, + 'dilations': self.dilations } - - output_group_channels = output_channels / self.groups - input_group_channels = input_channels / self.groups - for batchid in xrange(batch_size): - for group in xrange(self.groups): - for outchannelid in range(group * 
output_group_channels, - (group + 1) * output_group_channels): - for rowid in xrange(output_height): - for colid in xrange(output_width): - start_h = (rowid * stride) - padding - start_w = (colid * stride) - padding - output_value = 0.0 - for inchannelid in range( - group * input_group_channels, - (group + 1) * input_group_channels): - for frowid in xrange(filter_height): - for fcolid in xrange(filter_width): - input_value = 0.0 - inrowid = start_h + frowid - incolid = start_w + fcolid - if ((inrowid >= 0 and - inrowid < input_height) and - (incolid >= 0 and - incolid < input_width)): - input_value = input[batchid][ - inchannelid][inrowid][incolid] - filter_value = filter[outchannelid][ - inchannelid % input_group_channels][ - frowid][fcolid] - output_value += input_value * filter_value - output[batchid][outchannelid][rowid][ - colid] = output_value - self.outputs = {'Output': output} def test_check_output(self): @@ -90,14 +76,47 @@ class TestConv2dOp(OpTest): max_relative_error=0.05, no_grad_set=set(['Input'])) - def init_groups(self): + def init_test_case(self): + # self.groups = 1 + # self.op_type = "conv2d" + self.pad = [0, 0] + self.stride = [1, 1] + self.dilations = [1, 1] + self.input_size = [2, 3, 5, 5] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] / self.groups + self.filter_size = [6, f_c, 3, 3] + + def init_group(self): self.groups = 1 + def init_op_type(self): + self.op_type = "conv2d" + class TestWithGroup(TestConv2dOp): - def init_groups(self): + def init_group(self): self.groups = 3 + def init_op_type(self): + self.op_type = "conv2d" + + +class TestCudnn(TestConv2dOp): + def init_group(self): + self.groups = 1 + + def init_op_type(self): + self.op_type = "conv_cudnn" + + +class TestCudnnWithGroup(TestConv2dOp): + def init_group(self): + self.groups = 3 + + def init_op_type(self): + self.op_type = "conv_cudnn" + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/framework/tests/test_decayed_adagrad_op.py b/python/paddle/v2/framework/tests/test_decayed_adagrad_op.py new file mode 100644 index 0000000000000000000000000000000000000000..674c3fda5c82309bbfbbad936a8b0b26929d42d9 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_decayed_adagrad_op.py @@ -0,0 +1,71 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestDecayedAdagradOp1(OpTest): + ''' Test DecayedAdagrad operator with explicit attributes + ''' + + def setUp(self): + self.op_type = "decayed_adagrad" + + param = np.random.random((123, 321)).astype("float32") + grad = np.random.random((123, 321)).astype("float32") + moment = np.zeros((123, 321)).astype("float32") + lr = 0.01 + decay = 0.80 + epsilon = 1e-8 + + self.inputs = { + 'Param': param, + 'Grad': grad, + 'Moment': moment, + 'LearningRate': np.array([lr]).astype("float32") + } + + self.attrs = {'decay': decay, 'epsilon': epsilon} + + moment_out = decay * moment + (1 - decay) * grad * grad + param_out = param - lr * grad / (np.sqrt(moment_out) + epsilon) + + self.outputs = {'ParamOut': param_out, 'MomentOut': moment_out} + + def test_check_output(self): + self.check_output() + + +class TestDecayedAdagradOp2(OpTest): + ''' Test DecayedAdagrad operator with default attributes + ''' + + def setUp(self): + self.op_type = "decayed_adagrad" + + param = np.random.random((123, 321)).astype("float32") + grad = np.random.random((123, 321)).astype("float32") + moment = np.zeros((123, 321)).astype("float32") + lr = 0.01 + decay = 0.95 + epsilon = 1e-6 + + self.inputs = { + 
'Param': param, + 'Grad': grad, + 'Moment': moment, + 'LearningRate': np.array([lr]).astype("float32") + } + + self.attrs = {'decay': decay, 'epsilon': epsilon} + + moment_out = decay * moment + (1 - decay) * grad * grad + param_out = param - lr * grad / (np.sqrt(moment_out) + epsilon) + + self.outputs = {'ParamOut': param_out, 'MomentOut': moment_out} + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_operator_desc.py b/python/paddle/v2/framework/tests/test_operator_desc.py new file mode 100644 index 0000000000000000000000000000000000000000..d7a85d8e4e883efd268c53a0e4977533040a0a14 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_operator_desc.py @@ -0,0 +1,76 @@ +import unittest +from paddle.v2.framework.framework import Variable, g_program +import paddle.v2.framework.core as core + + +class TestOperator(unittest.TestCase): + def test_error_type(self): + block = g_program.create_block() + try: + block.append_op() + self.assertFail() + except ValueError as v_err: + self.assertEqual( + v_err.message, + "`type` to initilized an Operator can not be None.") + try: + block.append_op(type="no_such_op") + self.assertFail() + except AssertionError as a_err: + self.assertEqual(a_err.message, + "Operator \"no_such_op\" has not been registered.") + + def test_op_desc_creation(self): + block = g_program.current_block() + mul_x = block.create_var( + dtype="float32", shape=[5, 10], lod_level=0, name="mul.x") + mul_y = block.create_var( + dtype="float32", shape=[10, 8], lod_level=0, name="mul.y") + mul_out = block.create_var( + dtype="float32", shape=[5, 8], lod_level=0, name="mul.out") + mul_op = block.append_op( + type="mul", + inputs={"X": [mul_x], + "Y": mul_y}, + outputs={"Out": [mul_out]}, + attrs={"x_num_col_dims": 1}) + self.assertEqual(mul_op.type, "mul") + self.assertEqual(mul_op.input_names, ["X", "Y"]) + self.assertEqual(mul_op.input("X"), ["mul.x"]) + self.assertEqual(mul_op.input("Y"), ["mul.y"]) + self.assertEqual(mul_op.output_names, ["Out"]) + self.assertEqual(mul_op.output("Out"), ["mul.out"]) + self.assertEqual( + set(mul_op.attr_names), set(["x_num_col_dims", "y_num_col_dims"])) + self.assertEqual(mul_op.has_attr("x_num_col_dims"), True) + self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT) + self.assertEqual(mul_op.attr("x_num_col_dims"), 1) + self.assertEqual(mul_op.has_attr("y_num_col_dims"), True) + self.assertEqual(mul_op.attr_type("y_num_col_dims"), core.AttrType.INT) + self.assertEqual(mul_op.attr("y_num_col_dims"), 1) + self.assertEqual(mul_out.op, mul_op) + + def test_mult_input(self): + block = g_program.current_block() + sum_x1 = block.create_var( + dtype="int", shape=[3, 4], lod_level=0, name="sum.x1") + sum_x2 = block.create_var( + dtype="int", shape=[3, 4], lod_level=0, name="sum.x2") + sum_x3 = block.create_var( + dtype="int", shape=[3, 4], lod_level=0, name="sum.x3") + sum_out = block.create_var( + dtype="int", shape=[3, 4], lod_level=0, name="sum.out") + sum_op = block.append_op( + type="sum", + inputs={"X": [sum_x1, sum_x2, sum_x3]}, + outputs={"Out": sum_out}) + self.assertEqual(sum_op.type, "sum") + self.assertEqual(sum_op.input_names, ["X"]) + self.assertEqual(sum_op.input("X"), ["sum.x1", "sum.x2", "sum.x3"]) + self.assertEqual(sum_op.output_names, ["Out"]) + self.assertEqual(sum_op.output("Out"), ["sum.out"]) + self.assertEqual(sum_out.op, sum_op) + + +if __name__ == '__main__': + unittest.main() diff --git 
a/python/paddle/v2/framework/tests/test_parameter.py b/python/paddle/v2/framework/tests/test_parameter.py index 3b5d38f257e6f51be30d9f1fa42285461b2a0eb7..1ac0cdd99f1b7c15d64ae9d2c465d5a9d563bd80 100644 --- a/python/paddle/v2/framework/tests/test_parameter.py +++ b/python/paddle/v2/framework/tests/test_parameter.py @@ -1,5 +1,5 @@ import unittest -from paddle.v2.framework.graph import g_program +from paddle.v2.framework.framework import g_program import paddle.v2.framework.core as core diff --git a/python/paddle/v2/framework/tests/test_program.py b/python/paddle/v2/framework/tests/test_program.py index 83e184494ad235f6493a7ea8e25886b1e35004ee..64b781e6ea21bff90646d312a157d60852f276df 100644 --- a/python/paddle/v2/framework/tests/test_program.py +++ b/python/paddle/v2/framework/tests/test_program.py @@ -1,7 +1,7 @@ import unittest import paddle.v2.framework.core as core -from paddle.v2.framework.graph import g_program +from paddle.v2.framework.framework import g_program class TestProgram(unittest.TestCase): diff --git a/python/paddle/v2/framework/tests/test_protobuf_descs.py b/python/paddle/v2/framework/tests/test_protobuf_descs.py index 3db1e79ce43b7f559c7caab8397817b76d56161e..af5ed6801fa7b87e9193df78c7d28cf637eafa42 100644 --- a/python/paddle/v2/framework/tests/test_protobuf_descs.py +++ b/python/paddle/v2/framework/tests/test_protobuf_descs.py @@ -53,7 +53,7 @@ class TestOpDesc(unittest.TestCase): self.assertEqual(8, len(op.attr_names())) op.set_block_attr("block_attr", prog.block(0)) - self.assertEqual(0, op.get_block_attr("block_attr")) + self.assertEqual(0, op.block_attr("block_attr")) mul_op = block.append_op() mul_op.set_type("mul") diff --git a/python/paddle/v2/framework/tests/test_seq_concat_op.py b/python/paddle/v2/framework/tests/test_seq_concat_op.py index 6309b09bc98f6d529f80bfa269a0eaadd799fcbc..abd2ebf0b21a953b76155eb04c57a7b65ac53cbc 100644 --- a/python/paddle/v2/framework/tests/test_seq_concat_op.py +++ b/python/paddle/v2/framework/tests/test_seq_concat_op.py @@ -1,5 +1,6 @@ import unittest import numpy as np +import sys from op_test import OpTest @@ -74,4 +75,5 @@ class TestConcatOpLevelZero(TestConcatOp): if __name__ == '__main__': + sys.exit(0) unittest.main() diff --git a/python/paddle/v2/framework/tests/test_variable.py b/python/paddle/v2/framework/tests/test_variable.py index 8ea1083ff6535d2d517f2ac587a956bfed906f03..695aaaee6c0c1d035349b1d1716c24bab81e607b 100644 --- a/python/paddle/v2/framework/tests/test_variable.py +++ b/python/paddle/v2/framework/tests/test_variable.py @@ -1,5 +1,5 @@ import unittest -from paddle.v2.framework.graph import Variable, g_program +from paddle.v2.framework.framework import Variable, g_program import paddle.v2.framework.core as core import numpy as np
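As a quick reference for the framework.py additions above, here is a minimal usage sketch of the new OpProtoHolder singleton. It is not part of the patch; the operator type "mul" and the printed field values are illustrative (any registered operator works), but the method names match the code added in this change.

    from paddle.v2.framework.framework import OpProtoHolder

    # The holder lazily builds a {op_type: OpProto} map from core.get_all_op_protos().
    holder = OpProtoHolder.instance()
    mul_proto = holder.get_op_proto("mul")          # raises AssertionError for unregistered types
    print(mul_proto.type)                           # "mul"
    print([inp.name for inp in mul_proto.inputs])   # e.g. ["X", "Y"]
    print([out.name for out in mul_proto.outputs])  # e.g. ["Out"]

Operator.__init__ relies on exactly this lookup to validate non-duplicable inputs and outputs before filling in the underlying OpDesc.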