diff --git a/CMakeLists.txt b/CMakeLists.txt index dcff6b54cafce35846627e78cfcdac65fae7e686..2a6b0a20e441676c85c9ed8f8ad1a6e7abdf1ea8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,7 +13,6 @@ # limitations under the License cmake_minimum_required(VERSION 3.0) - set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake") set(PROJ_ROOT ${CMAKE_CURRENT_SOURCE_DIR}) set(PROJ_BINARY_ROOT ${CMAKE_CURRENT_BINARY_DIR}) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index e42e75c12ab1e5133f5ecbdb90ef26e3f8df5133..534be0abe246ac70950d85ad05441825c8ca768a 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -290,8 +290,22 @@ function(go_library TARGET_NAME) set(${TARGET_NAME}_LIB_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}${TARGET_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE STRING "output library name for target ${TARGET_NAME}") endif() - # Add dummy code to support `make target_name` under Terminal Command set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}_dummy.c) + + # This custom command will always run since it depends on a not + # existing file. + add_custom_command( + OUTPUT dummy_rebulid_${TARGET_NAME} + COMMAND cmake -E touch ${dummyfile} + ) + # Create a custom target that depends on the custom command output + # file, so the custom command can be referenced as a dependency by + # `add_dependencies`. + add_custom_target(rebuild_${TARGET_NAME} + DEPENDS dummy_rebulid_${TARGET_NAME} + ) + + # Add dummy code to support `make target_name` under Terminal Command file(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";") if (go_library_SHARED OR go_library_shared) add_library(${TARGET_NAME} SHARED ${dummyfile}) @@ -302,6 +316,12 @@ function(go_library TARGET_NAME) add_dependencies(${TARGET_NAME} ${go_library_DEPS}) endif(go_library_DEPS) + # The "source file" of the library is `${dummyfile}` which never + # change, so the target will never rebuild. Make the target depends + # on the custom command that touches the library "source file", so + # rebuild will always happen. + add_dependencies(${TARGET_NAME} rebuild_${TARGET_NAME}) + set(${TARGET_NAME}_LIB_PATH "${CMAKE_CURRENT_BINARY_DIR}/${${TARGET_NAME}_LIB_NAME}" CACHE STRING "output library path for target ${TARGET_NAME}") file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go") diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index eb3416462324edf6f6e76e32d7400d1fd774b9bd..760d84e51e7473d359a415e4790251db3d139ab2 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -1,23 +1,25 @@ # ddim lib -cc_library(enforce SRCS enforce.cc DEPS glog) -cc_test(enforce_test SRCS enforce_test.cc DEPS enforce) cc_library(ddim SRCS ddim.cc DEPS eigen3) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) nv_test(dim_test SRCS dim_test.cu DEPS ddim) -cc_library(tensor SRCS tensor.cc DEPS ddim place enforce paddle_memory) + +cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory) cc_test(tensor_test SRCS tensor_test.cc DEPS tensor) +cc_test(eigen_test SRCS eigen_test.cc DEPS tensor) + cc_test(variable_test SRCS variable_test.cc) cc_test(scope_test SRCS scope_test.cc) + proto_library(attr_type SRCS attr_type.proto) proto_library(op_proto SRCS op_proto.proto DEPS attr_type) -cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf) proto_library(op_desc SRCS op_desc.proto DEPS attr_type) +cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf) cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf) cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) -cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc enforce) +cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator) py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto) diff --git a/paddle/framework/attr_checker.h b/paddle/framework/attr_checker.h index c0c33d81149ac2fc2a9a57d90931ef32375fe1d0..ea5614a45f3a77a851358aff80abbc276c9972ba 100644 --- a/paddle/framework/attr_checker.h +++ b/paddle/framework/attr_checker.h @@ -4,8 +4,9 @@ #include #include #include +#include #include -#include "paddle/framework/enforce.h" +#include "paddle/platform/enforce.h" namespace paddle { namespace framework { @@ -41,6 +42,35 @@ class DefaultValueSetter { T default_value_; }; +template +class EnumInContainer { + public: + explicit EnumInContainer(const std::unordered_set& c) : container_(c) {} + void operator()(T& val) const { + PADDLE_ENFORCE(container_.find(val) != container_.end(), + "Value %s is not in enum container %s", val, + ContainerDebugString()); + } + + private: + std::string ContainerDebugString() const { + std::ostringstream sout; + sout << "["; + size_t cnt = 0; + for (auto& v : container_) { + sout << v; + ++cnt; + if (cnt != container_.size()) { + sout << " ,"; + } + } + sout << "]"; + return sout.str(); + } + + std::unordered_set container_; +}; + // check whether a certain attribute fit its limits // an attribute can have more than one limits template @@ -50,6 +80,11 @@ class TypedAttrChecker { public: TypedAttrChecker(const std::string& attr_name) : attr_name_(attr_name) {} + TypedAttrChecker& InEnum(const std::unordered_set& range) { + value_checkers_.push_back(EnumInContainer(range)); + return *this; + } + TypedAttrChecker& LargerThan(const T& lower_bound) { value_checkers_.push_back(LargerThanChecker(lower_bound)); return *this; diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc index d2ef85afe55e640a17b8c957bac61d175e69ff3f..545c1dcc2a1682839d90194002fdbb748d85e808 100644 --- a/paddle/framework/ddim.cc +++ b/paddle/framework/ddim.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/framework/ddim.h" -#include "paddle/framework/enforce.h" +#include "paddle/platform/enforce.h" namespace paddle { namespace framework { diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h index 070850375d1bd3a61b98184495c979573bf9542c..9fcc657edcd5459d0a42a64d708603a4bcd53cf0 100644 --- a/paddle/framework/ddim.h +++ b/paddle/framework/ddim.h @@ -19,7 +19,7 @@ limitations under the License. */ #include #include #include "paddle/framework/dim.h" -#include "paddle/framework/enforce.h" +#include "paddle/platform/enforce.h" #include "unsupported/Eigen/CXX11/Tensor" namespace paddle { @@ -119,17 +119,6 @@ int arity(const DDim& ddim); std::ostream& operator<<(std::ostream&, const DDim&); -template -Eigen::DSizes ToEigenDSizes(const DDim& dims) { - int rank = arity(dims); - PADDLE_ENFORCE(rank == NDIMS, "DDim and NDIMS must be same"); - Eigen::DSizes dsizes; - for (int d = 0; d < rank; d++) { - dsizes[d] = dims[d]; - } - return dsizes; -} - } // namespace framework } // namespace paddle diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h new file mode 100644 index 0000000000000000000000000000000000000000..4ba4fd4d110330805faf2468bd406cb23c6f1b1c --- /dev/null +++ b/paddle/framework/eigen.h @@ -0,0 +1,84 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/tensor.h" +#include "unsupported/Eigen/CXX11/Tensor" + +namespace paddle { +namespace framework { + +// EigenDim converts paddle::platform::DDim into Eigen::DSizes. +template +struct EigenDim { + using Type = Eigen::DSizes; + + static Type From(const DDim& dims) { + PADDLE_ENFORCE(arity(dims) == D, "D must match arity(DDim)"); + Type ret; + for (int d = 0; d < arity(dims); d++) { + ret[d] = dims[d]; + } + return ret; + } +}; + +// Interpret paddle::platform::Tensor as EigenTensor and EigenConstTensor. +template +struct EigenTensor { + // TODO(qijun) Now, default type in unaligned, and we will make a benchmark on + // the speed of aligned and unaligned version in future. + using Type = Eigen::TensorMap>; + + using ConstType = + Eigen::TensorMap>; + + static Type From(Tensor& tensor, DDim dims) { + return Type(tensor.data(), EigenDim::From(dims)); + } + + static Type From(Tensor& tensor) { return From(tensor, tensor.dims_); } + + static ConstType From(const Tensor& tensor, DDim dims) { + return ConstType(tensor.data(), EigenDim::From(dims)); + } + + static ConstType From(const Tensor& tensor) { + return From(tensor, tensor.dims_); + } +}; + +template +struct EigenVector : public EigenTensor { + // Flatten is to reshape a Tensor into a one dimension EigenVector + static typename EigenTensor::Type Flatten(Tensor& tensor) { + return EigenTensor::From( + tensor, make_ddim({static_cast(product(tensor.dims_))})); + } + + static typename EigenTensor::ConstType Flatten(const Tensor& tensor) { + return EigenTensor::From( + tensor, make_ddim({static_cast(product(tensor.dims_))})); + } +}; + +template +using EigenMatrix = EigenTensor; + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/eigen_test.cc b/paddle/framework/eigen_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a9fa728e49a0dcc781e520a22c1ee5f921c4c733 --- /dev/null +++ b/paddle/framework/eigen_test.cc @@ -0,0 +1,101 @@ +/* + Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +#include "paddle/framework/eigen.h" +#include + +namespace paddle { +namespace framework { + +TEST(EigenDim, From) { + EigenDim<3>::Type ed = EigenDim<3>::From(make_ddim({1, 2, 3})); + ASSERT_EQ(1, ed[0]); + ASSERT_EQ(2, ed[1]); + ASSERT_EQ(3, ed[2]); +} + +TEST(Eigen, Tensor) { + Tensor t; + float* p = t.mutable_data(make_ddim({1, 2, 3}), platform::CPUPlace()); + for (int i = 0; i < 1 * 2 * 3; i++) { + p[i] = static_cast(i); + } + + EigenTensor::Type et = EigenTensor::From(t); + + ASSERT_EQ(1, et.dimension(0)); + ASSERT_EQ(2, et.dimension(1)); + ASSERT_EQ(3, et.dimension(2)); + + for (int i = 0; i < 1; i++) { + for (int j = 0; j < 2; j++) { + for (int k = 0; k < 3; k++) { + ASSERT_NEAR((i * 2 + j) * 3 + k, et(i, j, k), 1e-6f); + } + } + } +} + +TEST(Eigen, VectorFrom) { + Tensor t; + float* p = t.mutable_data(make_ddim({6}), platform::CPUPlace()); + for (int i = 0; i < 6; i++) { + p[i] = static_cast(i); + } + + EigenVector::Type ev = EigenVector::From(t); + + ASSERT_EQ(6, ev.dimension(0)); + + for (int i = 0; i < 6; i++) { + ASSERT_NEAR(i, ev(i), 1e-6f); + } +} + +TEST(Eigen, VectorFlatten) { + Tensor t; + float* p = t.mutable_data(make_ddim({1, 2, 3}), platform::CPUPlace()); + for (int i = 0; i < 1 * 2 * 3; i++) { + p[i] = static_cast(i); + } + + EigenVector::Type ev = EigenVector::Flatten(t); + + ASSERT_EQ(1 * 2 * 3, ev.dimension(0)); + + for (int i = 0; i < 1 * 2 * 3; i++) { + ASSERT_NEAR(i, ev(i), 1e-6f); + } +} + +TEST(Eigen, Matrix) { + Tensor t; + float* p = t.mutable_data(make_ddim({2, 3}), platform::CPUPlace()); + for (int i = 0; i < 2 * 3; i++) { + p[i] = static_cast(i); + } + + EigenMatrix::Type em = EigenMatrix::From(t); + + ASSERT_EQ(2, em.dimension(0)); + ASSERT_EQ(3, em.dimension(1)); + + for (int i = 0; i < 2; i++) { + for (int j = 0; j < 3; j++) { + ASSERT_NEAR(i * 3 + j, em(i, j), 1e-6f); + } + } +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/enforce.cc b/paddle/framework/enforce.cc deleted file mode 100644 index 644930ff989bb8935f37642c117084f580379bd7..0000000000000000000000000000000000000000 --- a/paddle/framework/enforce.cc +++ /dev/null @@ -1,15 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#include "paddle/framework/enforce.h" diff --git a/paddle/framework/enforce.h b/paddle/framework/enforce.h deleted file mode 100644 index ffce8148e9516a5720757c87685ff6bd2937977c..0000000000000000000000000000000000000000 --- a/paddle/framework/enforce.h +++ /dev/null @@ -1,75 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include -#include -#include - -namespace paddle { -namespace framework { - -/** - * @brief Enforce exception. Inherits std::exception - * - * All enforce condition not met, will throw an EnforceNotMet exception. - */ -class EnforceNotMet : public std::exception { - public: - EnforceNotMet(const std::string& msg, const char* file, int fileline) { - std::ostringstream sout; - sout << msg << " at [" << file << ":" << fileline << "];"; - all_msg_ = sout.str(); - } - - const char* what() const noexcept override { return all_msg_.c_str(); } - - private: - std::string all_msg_; -}; - -// From https://stackoverflow.com/questions/30130930/ -// __buildin_expect is in C++ 11 standard. Since the condition which enforced -// should be true in most situation, it will make the compiler generate faster -// code by adding `UNLIKELY` macro. -#define UNLIKELY(condition) __builtin_expect(static_cast(condition), 0) - -/** - * @brief Throw a EnforceNotMet exception, automatically filled __FILE__ & - * __LINE__ - * - * This macro take __VA_ARGS__, user can pass any type if that type can - * serialize to std::ostream - */ -#define PADDLE_THROW(...) \ - do { \ - throw ::paddle::framework::EnforceNotMet( \ - ::paddle::string::Sprintf(__VA_ARGS__), __FILE__, __LINE__); \ - } while (0) - -/** - * @brief Enforce a condition, otherwise throw an EnforceNotMet - */ -#ifdef NDEBUG -#define PADDLE_ENFORCE(condition, ...) \ - do { \ - if (UNLIKELY(!(condition))) { \ - PADDLE_THROW(__VA_ARGS__); \ - } \ - } while (0) -#else -#define PADDLE_ENFORCE(condition, ...) \ - CHECK(condition) << ::paddle::string::Sprintf(__VA_ARGS__); -#endif - -} // namespace framework -} // namespace paddle diff --git a/paddle/framework/net.cc b/paddle/framework/net.cc index b9cd732d409e6b8ab6bdddcef810597ac28fba1d..501536657d76cc50b1cc4104007edd4b47758aea 100644 --- a/paddle/framework/net.cc +++ b/paddle/framework/net.cc @@ -19,7 +19,10 @@ namespace paddle { namespace framework { -void PlainNet::CompleteAddOp() { +void PlainNet::CompleteAddOp(bool calc) { + add_op_done_ = true; + if (!calc) return; + std::unordered_set input_set; std::unordered_set output_set; std::unordered_set temp_output; @@ -52,7 +55,6 @@ void PlainNet::CompleteAddOp() { } attrs_["temporary_index"] = tmp_index; - add_op_done_ = true; } std::string PlainNet::DebugString() const { diff --git a/paddle/framework/net.h b/paddle/framework/net.h index 33bb30ea0767b32e72888c9ff75970f8801f645a..19c5fa223b4e75b1f06ca14ded053cebfd8bffe2 100644 --- a/paddle/framework/net.h +++ b/paddle/framework/net.h @@ -16,7 +16,6 @@ limitations under the License. */ #include #include -#include "paddle/framework/net_proto.pb.h" #include "paddle/framework/op_proto.pb.h" #include "paddle/framework/op_registry.h" #include "paddle/framework/scope.h" @@ -41,7 +40,7 @@ namespace framework { class Net : public OperatorBase { public: virtual void AddOp(const OperatorPtr& op) = 0; - virtual void CompleteAddOp() = 0; + virtual void CompleteAddOp(bool calc) = 0; }; using NetPtr = std::shared_ptr; @@ -86,7 +85,7 @@ class PlainNet : public Net { ops_.push_back(op); } - void CompleteAddOp() override; + void CompleteAddOp(bool calculate = true) override; std::string DebugString() const override; diff --git a/paddle/framework/net_op_test.cc b/paddle/framework/net_op_test.cc index f5e1c22400a73c3aa09839ef9654f87def99bc77..e814a7e43d7ae7af0974d1a7c8b072bde5ba0238 100644 --- a/paddle/framework/net_op_test.cc +++ b/paddle/framework/net_op_test.cc @@ -63,5 +63,5 @@ TEST(OpKernel, all) { ASSERT_EQ(2, infer_shape_cnt); ASSERT_EQ(2, run_cnt); - ASSERT_THROW(net->AddOp(op2), paddle::framework::EnforceNotMet); + ASSERT_THROW(net->AddOp(op2), std::runtime_error); } diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index d3a51a361aa56b26b87d79057f6700bd87264ca4..32a7e88a894fb61a460443b7d593a6cf44bc98c5 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -91,7 +91,7 @@ TEST(OpRegistry, IllegalAttr) { try { paddle::framework::OperatorPtr op __attribute__((unused)) = paddle::framework::OpRegistry::CreateOp(op_desc); - } catch (paddle::framework::EnforceNotMet err) { + } catch (std::runtime_error& err) { caught = true; std::string msg = "larger_than check fail"; const char* err_msg = err.what(); @@ -138,7 +138,7 @@ TEST(OpRegistry, CustomChecker) { try { paddle::framework::OperatorPtr op __attribute__((unused)) = paddle::framework::OpRegistry::CreateOp(op_desc); - } catch (paddle::framework::EnforceNotMet err) { + } catch (std::runtime_error& err) { caught = true; std::string msg = "Attribute 'test_attr' is required!"; const char* err_msg = err.what(); @@ -157,7 +157,7 @@ TEST(OpRegistry, CustomChecker) { try { paddle::framework::OperatorPtr op __attribute__((unused)) = paddle::framework::OpRegistry::CreateOp(op_desc); - } catch (paddle::framework::EnforceNotMet err) { + } catch (std::runtime_error& err) { caught = true; std::string msg = "'test_attr' must be even!"; const char* err_msg = err.what(); @@ -196,7 +196,7 @@ TEST(ProtoMaker, DuplicatedAttr) { pd::OpProto op_proto; pd::OpAttrChecker op_checker; auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker); - ASSERT_THROW(proto_maker.Validate(), paddle::framework::EnforceNotMet); + ASSERT_THROW(proto_maker.Validate(), std::runtime_error); } class TestInOutProtoMaker : public pd::OpProtoAndCheckerMaker { @@ -212,5 +212,5 @@ TEST(ProtoMaker, DuplicatedInOut) { pd::OpProto op_proto; pd::OpAttrChecker op_checker; auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker); - ASSERT_THROW(proto_maker.Validate(), paddle::framework::EnforceNotMet); + ASSERT_THROW(proto_maker.Validate(), std::runtime_error); } diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 1dd421cdb681e15486e309ff912574af35b5a0c2..93c6fad5d3d9f3de100d30161e6e438eb43816a2 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -19,9 +19,8 @@ limitations under the License. */ #include #include #include "paddle/framework/ddim.h" -#include "paddle/framework/enforce.h" -#include "paddle/framework/tensor_types.h" #include "paddle/memory/memory.h" +#include "paddle/platform/enforce.h" #include "paddle/platform/place.h" #include "unsupported/Eigen/CXX11/Tensor" @@ -35,6 +34,15 @@ struct CastToPyBufferImpl; namespace framework { class Tensor { + template + friend struct paddle::pybind::details::CastToPyBufferImpl; + + template + friend struct EigenTensor; + + template + friend struct EigenVector; + public: Tensor() : offset_(0) {} @@ -46,7 +54,7 @@ class Tensor { } template - T* raw_data() const { + T* data() { CheckDims(); return reinterpret_cast(reinterpret_cast(holder_->ptr()) + offset_); @@ -86,66 +94,6 @@ class Tensor { offset_); } - template - typename TTypes::Tensor shaped(DDim new_dims) { - Eigen::array dims = - paddle::framework::ToEigenDSizes(new_dims); - return typename TTypes::Tensor(raw_data(), dims); - } - - template - typename TTypes::Tensor tensor() { - return typename TTypes::Tensor( - raw_data(), paddle::framework::ToEigenDSizes(dims_)); - } - - // flat to rank = 1 - template - typename TTypes::Flat flat() { - return shaped(make_ddim({static_cast(product(dims_))})); - } - - // to TensorType Vec - template - typename TTypes::Vec vec() { - return tensor(); - } - - // to TensorType Matrix - template - typename TTypes::Matrix matrix() { - return tensor(); - } - - // const versions of all the methods above. - template - typename TTypes::Tensor shaped(DDim new_dims) const { - Eigen::array dims = - paddle::framework::ToEigenDSizes(new_dims); - return typename TTypes::Tensor(data(), dims); - } - - template - typename TTypes::ConstantTensor tensor() const { - return typename TTypes::Tensor( - data(), paddle::framework::ToEigenDSizes(dims_)); - } - - template - typename TTypes::ConstFlat flat() const { - return shaped(make_ddim({static_cast(product(dims_))})); - } - - template - typename TTypes::ConstVec vec() const { - return tensor(); - } - - template - typename TTypes::ConstMatrix matrix() const { - return tensor(); - } - template void ShareDataFrom(const Tensor& src) { src.CheckDims(); @@ -251,8 +199,6 @@ class Tensor { std::shared_ptr holder_; // holds the memory block if allocated. DDim dims_; size_t offset_; // marks the begin of tensor data area. - template - friend struct paddle::pybind::details::CastToPyBufferImpl; }; } // namespace framework diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc index 84c6f0cf6558819440458688ca52b06c1cf11dd0..8a7cbbd0de6fd6aaafa8649abb8628e971bc49c1 100644 --- a/paddle/framework/tensor_test.cc +++ b/paddle/framework/tensor_test.cc @@ -33,7 +33,7 @@ TEST(Tensor, DataAssert) { bool caught = false; try { src_tensor.data(); - } catch (paddle::framework::EnforceNotMet err) { + } catch (std::runtime_error& err) { caught = true; std::string msg = "Tenosr holds no memory. Call Tensor::mutable_data first."; @@ -107,7 +107,7 @@ TEST(Tensor, ShareDataFrom) { bool caught = false; try { dst_tensor.ShareDataFrom(src_tensor); - } catch (EnforceNotMet err) { + } catch (std::runtime_error& err) { caught = true; std::string msg = "Tenosr holds no memory. Call Tensor::mutable_data first."; diff --git a/paddle/framework/tensor_types.h b/paddle/framework/tensor_types.h deleted file mode 100644 index 4bf27a377e828a56f9679e6698d314457d7caf0b..0000000000000000000000000000000000000000 --- a/paddle/framework/tensor_types.h +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "unsupported/Eigen/CXX11/Tensor" - -namespace paddle { -namespace framework { - -// Helper to define Tensor types given that the scalar is of type T. -template -struct TTypes { - // Rank- tensor of scalar type T. - typedef Eigen::TensorMap, - Eigen::Aligned> - Tensor; - typedef Eigen::TensorMap< - Eigen::Tensor, Eigen::Aligned> - ConstTensor; - - // Scalar tensor (implemented as a rank-0 tensor) of scalar type T. - typedef Eigen::TensorMap< - Eigen::TensorFixedSize, Eigen::RowMajor, IndexType>, - Eigen::Aligned> - Scalar; - typedef Eigen::TensorMap, - Eigen::RowMajor, IndexType>, - Eigen::Aligned> - ConstScalar; - - // Rank-1 tensor (vector) of scalar type T. - typedef Eigen::TensorMap, - Eigen::Aligned> - Flat; - typedef Eigen::TensorMap< - Eigen::Tensor, Eigen::Aligned> - ConstFlat; - typedef Eigen::TensorMap, - Eigen::Aligned> - Vec; - typedef Eigen::TensorMap< - Eigen::Tensor, Eigen::Aligned> - ConstVec; - - // Rank-2 tensor (matrix) of scalar type T. - typedef Eigen::TensorMap, - Eigen::Aligned> - Matrix; - typedef Eigen::TensorMap< - Eigen::Tensor, Eigen::Aligned> - ConstMatrix; -}; - -} // namespace framework -} // namespace paddle diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt index a5b14c0c71c18da1bb0b506c663f8680b1c3830a..2bec00cdb2d32d01a5a24e662bcca07f4154939c 100644 --- a/paddle/function/CMakeLists.txt +++ b/paddle/function/CMakeLists.txt @@ -36,6 +36,7 @@ if(WITH_GPU) add_simple_unittest(MulOpTest) add_simple_unittest(CosSimOpTest) add_simple_unittest(RowConvOpTest) + add_simple_unittest(CropOpTest) endif() add_simple_unittest(ConvOpTest) diff --git a/paddle/function/CropOp.cpp b/paddle/function/CropOp.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f12ee43e3d72f9ac776eaff93914228850694dd2 --- /dev/null +++ b/paddle/function/CropOp.cpp @@ -0,0 +1,177 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "CropOp.h" +#include "paddle/function/TensorShape.h" +#include "paddle/math/Vector.h" + +namespace paddle { + +template <> +void Crop(real* outputs, + const real* inputs, + const TensorShape inShape, + const TensorShape outShape, + const FuncConfig& conf) { + std::vector crop_corner = + conf.get>("crop_corner"); + int cCrop = crop_corner[1]; + int hCrop = crop_corner[2]; + int wCrop = crop_corner[3]; + + int num = inShape[0]; + int inC = inShape[1]; + int inH = inShape[2]; + int inW = inShape[3]; + + int outC = outShape[1]; + int outH = outShape[2]; + int outW = outShape[3]; + + for (int n = 0; n < num; n++) { + for (int c = 0; c < outC; c++) { + for (int h = 0; h < outH; h++) { + int outoff = ((n * outC + c) * outH + h) * outW; + int inoff = ((n * inC + c + cCrop) * inH + h + hCrop) * inW + wCrop; + memcpy(outputs + outoff, inputs + inoff, outW * sizeof(real)); + } + } + } +} + +template <> +void CropGrad(const real* inGrad, + real* outGrad, + const TensorShape inShape, + const TensorShape outShape, + const FuncConfig& conf) { + std::vector crop_corner = + conf.get>("crop_corner"); + int cCrop = crop_corner[1]; + int hCrop = crop_corner[2]; + int wCrop = crop_corner[3]; + + int num = outShape[0]; + int outC = outShape[1]; + int outH = outShape[2]; + int outW = outShape[3]; + + int inC = inShape[1]; + int inH = inShape[2]; + int inW = inShape[3]; + + for (int n = 0; n < num; n++) { + for (int c = 0; c < inC; c++) { + for (int h = 0; h < inH; h++) { + int outoff = ((n * outC + c + cCrop) * outH + h + hCrop) * outW + wCrop; + int inoff = ((n * inC + c) * inH + h) * inW; + CpuVector inG = CpuVector(inW, const_cast(inGrad + inoff)); + CpuVector outG = CpuVector(inW, outGrad + outoff); + outG += inG; + } + } + } +} + +/** + * \brief Crop input according to the specify corner and shape. + * The input and output is a 4D tensor. In CropFunc, we only + * crop the 2nd to 4th dimension. + * + * Argument in this Function: + * \param pad_ A struct object contains the cropping corner and shape. + * \param inputs A 4D tensor, only one input. + * \param outputs A 4D tensor, the output value after cropping. + * + * For example, + * Input(2,2,2,3) = [ + * [ [[1,2,3], [3,4,5]], + * [[2,3,5], [1,6,7]] ], + * [ [[4,3,1], [1,8,7]], + * [[3,8,9], [2,3,5]] ] + * ] # the input shape is (2,2,2,3) + * + * pad_: if corner = (0,1,1) and crop_shape = (2,1,2) + * Output(2,2,1,2) = [ + * [ [[4,5]], + * [[6,7]] ], + * [ [[8,7]], + * [[3,5]] ] + * ] # the input shape is (2,2,2,3) + */ +template +class CropFunc : public FunctionBase { +public: + void init(const FuncConfig& config) override { conf_ = config; } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(1UL, inputs.size()); + CHECK_EQ(1UL, outputs.size()); + CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); + + TensorShape inShape = inputs[0].shape(); + TensorShape outShape = outputs[0].shape(); + + Crop(outputs[0].data(), + inputs[0].data(), + inShape, + outShape, + conf_); + } + +private: + FuncConfig conf_; +}; + +/** + * \brief The backward propagation of cropping Function. + * + * Argument in this Function: + * \param crop_ The same meaning as it in CropFunc. + * \param inputs The gradient with respect to the output value of CropFunc. + * \param outputs The gradient with respect to the input value of CropFunc. + */ + +template +class CropGradFunc : public FunctionBase { +public: + void init(const FuncConfig& config) override { conf_ = config; } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(1UL, inputs.size()); + CHECK_EQ(1UL, outputs.size()); + CHECK_EQ(outputs[0].getArgType(), ADD_TO); + + TensorShape outShape = outputs[0].shape(); + TensorShape inShape = inputs[0].shape(); + + CropGrad(inputs[0].data(), + outputs[0].data(), + inShape, + outShape, + conf_); + } + +private: + FuncConfig conf_; +}; + +REGISTER_TYPED_FUNC(Crop, CPU, CropFunc); +REGISTER_TYPED_FUNC(CropGrad, CPU, CropGradFunc); +#ifndef PADDLE_ONLY_CPU +REGISTER_TYPED_FUNC(Crop, GPU, CropFunc); +REGISTER_TYPED_FUNC(CropGrad, GPU, CropGradFunc); +#endif + +} // namespace paddle diff --git a/paddle/function/CropOp.h b/paddle/function/CropOp.h new file mode 100644 index 0000000000000000000000000000000000000000..87986fbdc7e33aeb24d947e82a5d67ba23f532de --- /dev/null +++ b/paddle/function/CropOp.h @@ -0,0 +1,51 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "Function.h" + +namespace paddle { + +/** + * \brief This funtion crops inputs according to the specify start point and + *shape. + * + * \param[out] outputs save results. + * \param[in] inputs input data. + * \param[in] inShape the shape of input tensor. + * \param[in] conf the cropping config + */ +template +void Crop(real* outputs, + const real* inputs, + const TensorShape inShape, + const TensorShape outShape, + const FuncConfig& conf); + +/** + * \brief Cropping operation backward. + * + * \param[out] inGrad gradients of previous layer + * \param[in] outGrad output gradient + * \param[in] inShape the shape of input tensor. + * \param[in] conf the cropping config + */ +template +void CropGrad(const real* inGrad, + real* outGrad, + const TensorShape inShape, + const TensorShape outShape, + const FuncConfig& conf); +} // namespace paddle diff --git a/paddle/function/CropOpGpu.cu b/paddle/function/CropOpGpu.cu new file mode 100644 index 0000000000000000000000000000000000000000..37ce6de0647e5e06a231710b5a53089533de2407 --- /dev/null +++ b/paddle/function/CropOpGpu.cu @@ -0,0 +1,113 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "hl_base.h" +#include "CropOp.h" + +namespace paddle { + +__global__ void KeCrop(real* outputs, const real* inputs, + int inC, int inH, int inW, + int cropC, int cropH, int cropW, + int outC, int outH, int outW, int nthreads) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < nthreads) { + const int w = idx % outW; + const int h = (idx / outW) % outH; + const int c = (idx / outW / outH) % outC; + const int n = idx / outW / outH / outC; + + const int off = ((n * inC + c + cropC) * inH + h + cropH) * inW + cropW + w; + outputs[idx] = inputs[off]; + } +} + +template <> +void Crop(real* outputs, + const real* inputs, + const TensorShape inShape, + const TensorShape outShape, + const FuncConfig& conf) { + std::vector crop_corner = conf.get>("crop_corner"); + int cropC = crop_corner[1]; + int cropH = crop_corner[2]; + int cropW = crop_corner[3]; + + int num = inShape[0]; + int inC = inShape[1]; + int inH = inShape[2]; + int inW = inShape[3]; + + int outC = outShape[1]; + int outH = outShape[2]; + int outW = outShape[3]; + + size_t nth = num * outC * outH * outW; + int blockSize = 1024; + int gridSize = (nth + blockSize - 1) / blockSize; + + KeCrop<<>> + (outputs, inputs, inC, inH, inW, cropC, cropH, cropW, + outC, outH, outW, nth); + CHECK_SYNC("Crop"); +} + +__global__ void KeCropDiff(const real* inGrad, real* outGrad, + int inC, int inH, int inW, + int cropC, int cropH, int cropW, + int outC, int outH, int outW, int nthreads) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < nthreads) { + const int w = idx % inW; + const int h = (idx / inW) % inH; + const int c = (idx / inW / inH) % inC; + const int n = idx / inW / inH / inC; + + const int off = ((n * outC + c + cropC) * outH + h + cropH) * outW + cropW + w; + + outGrad[off] += inGrad[idx]; + } +} + +template <> +void CropGrad(const real* inGrad, + real* outGrad, + const TensorShape inShape, + const TensorShape outShape, + const FuncConfig& conf) { + std::vector crop_corner = conf.get>("crop_corner"); + int cropC = crop_corner[1]; + int cropH = crop_corner[2]; + int cropW = crop_corner[3]; + + int num = outShape[0]; + int outC = outShape[1]; + int outH = outShape[2]; + int outW = outShape[3]; + + int inC = inShape[1]; + int inH = inShape[2]; + int inW = inShape[3]; + + size_t nth = num * inC * inH * inW; + int blockSize = 1024; + int gridSize = (nth + blockSize - 1) / blockSize; + + KeCropDiff <<>> + (inGrad, outGrad, inC, inH, inW, cropC, cropH, cropW, + outC, outH, outW, nth); + CHECK_SYNC("CropGrad"); +} + +} // namespace paddle diff --git a/paddle/function/CropOpTest.cpp b/paddle/function/CropOpTest.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6f11abfdf6f752857e0a75c62fb2b5c089c206d9 --- /dev/null +++ b/paddle/function/CropOpTest.cpp @@ -0,0 +1,49 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "FunctionTest.h" + +namespace paddle { + +TEST(Crop, real) { + for (size_t numSamples : {5, 32}) { + for (size_t channels : {5, 5, 32}) { + for (size_t imgSizeH : {5, 33, 100}) { + for (size_t imgSizeW : {5, 32, 96}) { + VLOG(3) << " numSamples=" << numSamples << " channels=" << channels + << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW; + for (bool test_grad : {false, true}) { + CpuGpuFuncCompare compare( + test_grad ? "CropGrad" : "Crop", + FuncConfig() + .set>("crop_corner", {0, 1, 1, 1}) + .set>("crop_shape", {0, 2, 3, 3})); + TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW}; + TensorShape outDims{numSamples, 2, 3, 3}; + compare.addInputs( + BufferArg(VALUE_TYPE_FLOAT, test_grad ? outDims : inDims)); + compare.addOutputs(BufferArg(VALUE_TYPE_FLOAT, + test_grad ? inDims : outDims, + test_grad ? ADD_TO : ASSIGN_TO), + test_grad ? ADD_TO : ASSIGN_TO); + compare.run(); + } + } + } + } + } +} + +} // namespace paddle diff --git a/paddle/gserver/layers/CropLayer.cpp b/paddle/gserver/layers/CropLayer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..69ad913420bdb6e1b2ed0618b7f9b78d7477be99 --- /dev/null +++ b/paddle/gserver/layers/CropLayer.cpp @@ -0,0 +1,146 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "CropLayer.h" +#include "paddle/utils/Stat.h" +namespace paddle { + +REGISTER_LAYER(crop, CropLayer); + +bool CropLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + /* Initialize the basic parent class */ + Layer::init(layerMap, parameterMap); + CHECK_LE(static_cast(inputLayers_.size()), 2); + CHECK_GE(static_cast(inputLayers_.size()), 1); + crop_axis_ = config_.axis(); + for (int i = 0; i < config_.offset_size(); i++) { + crop_offsets_.push_back(config_.offset(i)); + } + + // 1. get input_0 shape + auto& input0_img_conf = config_.inputs(0).image_conf(); + inDims_ = TensorShape({0, + input0_img_conf.channels(), + input0_img_conf.has_img_size_y() + ? input0_img_conf.img_size_y() + : input0_img_conf.img_size(), + input0_img_conf.img_size()}); + // 2. get target dims from config + if (config_.inputs_size() == 1) { + targetDims_ = TensorShape({config_.shape(0), + config_.shape(1), + config_.shape(2), + config_.shape(3)}); + } else { + // 2. get input_1 shape + auto& input1_img_conf = config_.inputs(1).image_conf(); + targetDims_ = TensorShape({0, + input1_img_conf.channels(), + input1_img_conf.has_img_size_y() + ? input1_img_conf.img_size_y() + : input1_img_conf.img_size(), + input1_img_conf.img_size()}); + } + + // 3. get final crop corner + int dimSize = 4; + crop_corner_ = {0, 0, 0, 0}; + for (int i = 0; i < dimSize; i++) { + if (i >= crop_axis_) { + if (crop_offsets_.size() > 1) { + crop_corner_[i] = crop_offsets_[i - crop_axis_]; + } else { + crop_corner_[i] = crop_offsets_[0]; + } + } + } + + outDims_ = TensorShape(4); + + createFunction( + forward_, "Crop", FuncConfig().set("crop_corner", crop_corner_)); + createFunction( + backward_, "CropGrad", FuncConfig().set("crop_corner", crop_corner_)); + + return true; +} + +void CropLayer::setOutDims() { + MatrixPtr input = inputLayers_[1]->getOutputValue(); + size_t batchSize = input->getHeight(); + // get target dims from input_1 + if (config_.inputs_size() == 2) { + targetDims_.setDim(0, batchSize); + int ch = config_.inputs(0).image_conf().channels(); + if (ch != 0) targetDims_.setDim(1, ch); + int h = inputLayers_[1]->getOutput().getFrameHeight(); + if (h != 0) targetDims_.setDim(2, h); + int w = inputLayers_[1]->getOutput().getFrameWidth(); + if (w != 0) targetDims_.setDim(3, w); + } + // get final crop shape from target dims and crop axis + std::vector crop_shape; + int dimSize = 4; + for (int i = 0; i < dimSize; i++) { + if (i >= crop_axis_) { + crop_shape.push_back(targetDims_[i]); + } else { + crop_shape.push_back(inDims_[i]); + } + } + + outDims_.reshape( + {crop_shape[0], crop_shape[1], crop_shape[2], crop_shape[3]}); + output_.setFrameHeight(crop_shape[2]); + output_.setFrameWidth(crop_shape[3]); +} + +void CropLayer::setInDims() { + MatrixPtr input = inputLayers_[0]->getOutputValue(); + size_t batchSize = input->getHeight(); + inDims_.setDim(0, batchSize); + int h = inputLayers_[0]->getOutput().getFrameHeight(); + if (h != 0) inDims_.setDim(2, h); + int w = inputLayers_[0]->getOutput().getFrameWidth(); + if (w != 0) inDims_.setDim(3, w); +} + +void CropLayer::forward(PassType passType) { + Layer::forward(passType); + setInDims(); + setOutDims(); + int size = outDims_[1] * outDims_[2] * outDims_[3]; + resetOutput(outDims_[0], size); + MatrixPtr outV = getOutputValue(); + REGISTER_TIMER_INFO("CropForward", getName().c_str()); + + BufferArgs inputs; + BufferArgs outputs; + inputs.addArg(*getInputValue(0), inDims_); + outputs.addArg(*getOutputValue(), outDims_, ASSIGN_TO); + forward_[0]->calc(inputs, outputs); +} + +void CropLayer::backward(const UpdateCallback& callback) { + (void)callback; + REGISTER_TIMER_INFO("CropBackward", getName().c_str()); + + BufferArgs inputs; + BufferArgs outputs; + inputs.addArg(*getOutputGrad(), outDims_); + outputs.addArg(*getInputGrad(0), inDims_, ADD_TO); + backward_[0]->calc(inputs, outputs); +} +} // namespace paddle diff --git a/paddle/gserver/layers/CropLayer.h b/paddle/gserver/layers/CropLayer.h new file mode 100644 index 0000000000000000000000000000000000000000..6b6202621023575c1c83049ecbd019656c726e3f --- /dev/null +++ b/paddle/gserver/layers/CropLayer.h @@ -0,0 +1,52 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "Layer.h" + +namespace paddle { + +/** + * \brief This layer crop input according to the specify conf. + * input_0: input to be cropped + * input_1: optional reference input + * axis: start dimension to be croped + * offset: offset of cropping in each dimension + * shape: if reference input layer was not setted, + * crop input as this shape conf + */ +class CropLayer : public Layer { +public: + explicit CropLayer(const LayerConfig& config) : Layer(config) {} + + ~CropLayer() {} + + bool init(const LayerMap& layerMap, + const ParameterMap& parameterMap) override; + void forward(PassType passType) override; + void backward(const UpdateCallback& callback = nullptr) override; + +protected: + void setOutDims(); + void setInDims(); + + int32_t crop_axis_; + std::vector crop_offsets_; + std::vector crop_corner_; + TensorShape inDims_; + TensorShape targetDims_; + TensorShape outDims_; +}; +} // namespace paddle diff --git a/paddle/gserver/tests/CMakeLists.txt b/paddle/gserver/tests/CMakeLists.txt index 92f6cbcfe5a0e23c5939b1689a3e339367450387..a43adc7ce7db937bd62ea9bf1533b8a5899c259a 100644 --- a/paddle/gserver/tests/CMakeLists.txt +++ b/paddle/gserver/tests/CMakeLists.txt @@ -56,7 +56,7 @@ add_test(NAME test_DetectionOutput add_unittest_without_exec(test_ConvUnify test_ConvUnify.cpp LayerGradUtil.cpp) - + add_test(NAME test_ConvUnify COMMAND test_ConvUnify) ################# test_BatchNorm ####################### diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 2b45483bcc7faf13f56064c8d7f520b0f26c9a39..0975c3bc9573c6ccb8f0ac98c41586d322d2465e 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -1851,6 +1851,34 @@ TEST(Layer, RowConvLayer) { } } +TEST(Layer, CropLayer) { + TestConfig config; + // config input_0 + config.inputDefs.push_back({INPUT_DATA, "layer_0", 1024, 0}); + LayerInputConfig* input = config.layerConfig.add_inputs(); + ImageConfig* img = input->mutable_image_conf(); + img->set_channels(4); + img->set_img_size(16); + config.layerConfig.set_axis(2); + config.layerConfig.add_offset(0); + config.layerConfig.add_offset(0); + + // config input_1 + config.inputDefs.push_back({INPUT_DATA, "layer_1", 128, 0}); + input = config.layerConfig.add_inputs(); + img = input->mutable_image_conf(); + img->set_channels(2); + img->set_img_size(8); + + // config crop layer + config.layerConfig.set_type("crop"); + config.layerConfig.set_name("cropLayer"); + + for (auto useGpu : {false, true}) { + testLayerGrad(config, "crop", 100, false, useGpu, false); + } +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); initMain(argc, argv); diff --git a/paddle/memory/detail/system_allocator.cc b/paddle/memory/detail/system_allocator.cc index 1579174b1a6ff08824629d833d01411cff651f48..f61e67a32906083881dd7f47433521876be9b355 100644 --- a/paddle/memory/detail/system_allocator.cc +++ b/paddle/memory/detail/system_allocator.cc @@ -14,7 +14,7 @@ limitations under the License. */ #include "paddle/memory/detail/system_allocator.h" #include "paddle/platform/assert.h" -#include "paddle/platform/error.h" +#include "paddle/platform/enforce.h" #include "paddle/platform/gpu_info.h" #include // for malloc and free @@ -128,8 +128,7 @@ void GPUAllocator::Free(void* p, size_t size, size_t index) { // process is terminating, in which case we don't care if // cudaFree succeeds. if (err != cudaErrorCudartUnloading) { - platform::throw_on_error(err, - "cudaFree{Host} failed in GPUAllocator::Free."); + PADDLE_ENFORCE(err, "cudaFree{Host} failed in GPUAllocator::Free."); } } diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index f47c3a42083f289d6c99fe6df62e3478e0363e31..a37720e5093342f5e02bd9a15a3099de434d6396 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -27,7 +27,8 @@ function(op_library TARGET) endif() list(LENGTH cu_srcs cu_srcs_len) - if (${cu_srcs_len} EQUAL 0) + list(LENGTH op_library_DEPS dep_len) + if (${cu_srcs_len} EQUAL 0 AND ${dep_len} EQUAL 0) message(WARNING "The op library ${TARGET} not support GPU!") endif() @@ -47,3 +48,8 @@ op_library(mul_op SRCS mul_op.cc mul_op.cu) op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc) op_library(sigmoid_op SRCS sigmoid_op.cu sigmoid_op.cc) op_library(softmax_op SRCS softmax_op.cc softmax_op.cu) + +op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op + softmax_op net) + +op_library(sgd_op SRCS sgd_op.cc sgd_op.cu) diff --git a/paddle/operators/add_op.h b/paddle/operators/add_op.h index e08b3fb18775e2536a13bc838f40472c5c3e7ff7..39d54a63bd16cdafeec1cfcd86ef5d142382e880 100644 --- a/paddle/operators/add_op.h +++ b/paddle/operators/add_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once #include "glog/logging.h" +#include "paddle/framework/eigen.h" #include "paddle/framework/operator.h" namespace paddle { @@ -29,8 +30,10 @@ public: output->mutable_data(context.GetPlace()); - output->flat().device(*(context.GetEigenDevice())) = - input0.flat() + input1.flat(); + framework::EigenVector::Flatten(*output).device( + *(context.GetEigenDevice())) = + framework::EigenVector::Flatten(input0) + + framework::EigenVector::Flatten(input1); } }; diff --git a/paddle/operators/fc_op.cc b/paddle/operators/fc_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..01e96f4c4817466e3266ca57a0d0ae2368b3e097 --- /dev/null +++ b/paddle/operators/fc_op.cc @@ -0,0 +1,76 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/net.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" + +namespace paddle { +namespace operators { + +class FullyConnectedOp : public framework::PlainNet { +public: + void Init() override { + AddOp(framework::OpRegistry::CreateOp("mul", + { + Input("X"), Input("W"), + }, + {Output("before_act")}, + {})); + auto b = Input("b"); + if (b != framework::OperatorBase::EMPTY_VAR_NAME()) { + AddOp(framework::OpRegistry::CreateOp("rowwise_add", + {Output("before_act"), Input("b")}, + {Output("before_act")}, + {})); + } + + auto activation = GetAttr("activation"); + AddOp(framework::OpRegistry::CreateOp( + activation, {Output("before_act")}, {Output("Y")}, {})); + CompleteAddOp(false); + } +}; + +class FullyConnectedOpMaker : public framework::OpProtoAndCheckerMaker { +public: + FullyConnectedOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "the input of fc operator"); + AddInput("W", "the weight of fc operator"); + AddInput("b", "the bias of fc operator"); + + AddOutput("Y", "the output of fc operator"); + AddOutput( + "before_act", "the before activation output of fc operator", true); + AddAttr("activation", "The activation key for fc layer") + .SetDefault("sigmoid") + .InEnum({"sigmoid", "softmax"}); + + //! TODO(yuyang18): Complete comment; + AddComment("FullyConnected Operator"); + } +}; +} // namespace operators +} // namespace paddle + +USE_OP(mul); +USE_OP(rowwise_add); +USE_OP(sigmoid); +USE_OP(softmax); + +REGISTER_OP(fc, + paddle::operators::FullyConnectedOp, + paddle::operators::FullyConnectedOpMaker); diff --git a/paddle/operators/sgd_op.cc b/paddle/operators/sgd_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..04df87a3add2af7daa127a072f7b690f6cf94327 --- /dev/null +++ b/paddle/operators/sgd_op.cc @@ -0,0 +1,61 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/sgd_op.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/tensor.h" + +namespace paddle { +namespace operators { + +class SGDOp : public framework::OperatorWithKernel { +protected: + void InferShape( + const std::vector &inputs, + const std::vector &outputs) const override { + PADDLE_ENFORCE(inputs.size() == 2, "Input size of SGDOp must be two"); + PADDLE_ENFORCE(outputs.size() == 1, "Output size of SGDOp must be one"); + PADDLE_ENFORCE(inputs[0] != nullptr, "inputs[0] mast be set"); + PADDLE_ENFORCE(inputs[1] != nullptr, "inputs[1] mast be set"); + PADDLE_ENFORCE(outputs[0] != nullptr, "outputs[0] mast be set"); + PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(), + "Two input of SGD Op's dimension must be same."); + outputs[0]->set_dims(inputs[0]->dims()); + } +}; + +class SGDOpMaker : public framework::OpProtoAndCheckerMaker { +public: + SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : framework::OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("param", "input parameter"); + AddInput("grad", "input gradient"); + AddOutput("param_out", "output parameter"); + AddAttr("learning_rate", "learning rate of sgd"); + AddComment(R"DOC( + +Simplest sgd algorithm. + +param_out = param - learning_rate * grad; + +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +REGISTER_OP(sgd, paddle::operators::SGDOp, paddle::operators::SGDOpMaker); +typedef paddle::operators::SGDOpKernel<::paddle::platform::CPUPlace, float> + SGDOpKernel_CPU_float; +REGISTER_OP_CPU_KERNEL(sgd, SGDOpKernel_CPU_float); diff --git a/paddle/operators/sgd_op.cu b/paddle/operators/sgd_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..400425db10896e3970fc7468e34aba596a536184 --- /dev/null +++ b/paddle/operators/sgd_op.cu @@ -0,0 +1,5 @@ +#include "paddle/operators/sgd_op.h" +#include "paddle/framework/op_registry.h" + +typedef paddle::operators::SGDOpKernel<::paddle::platform::GPUPlace, float> SGDOpKernel_GPU_float; +REGISTER_OP_GPU_KERNEL(sgd, SGDOpKernel_GPU_float); \ No newline at end of file diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h new file mode 100644 index 0000000000000000000000000000000000000000..4b2d214618e5c7c15695bd66604139d805255c47 --- /dev/null +++ b/paddle/operators/sgd_op.h @@ -0,0 +1,42 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "glog/logging.h" +#include "paddle/framework/eigen.h" +#include "paddle/framework/operator.h" + +namespace paddle { +namespace operators { + +template +class SGDOpKernel : public framework::OpKernel { +public: + void Compute(const framework::KernelContext& ctx) const override { + auto param = ctx.Input("param")->Get(); + auto grad = ctx.Input("grad")->Get(); + auto* param_out = ctx.Output(0)->GetMutable(); + float lr = ctx.op_.GetAttr("learning_rate"); + + param_out->mutable_data(ctx.GetPlace()); + + framework::EigenVector::Flatten(*param_out) + .device(*(ctx.GetEigenDevice())) = + framework::EigenVector::Flatten(param) - + lr * framework::EigenVector::Flatten(grad); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/sgd_op_test.cc b/paddle/operators/sgd_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..75137259f5e608b259b073101353e5818bb17c92 --- /dev/null +++ b/paddle/operators/sgd_op_test.cc @@ -0,0 +1,22 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +USE_OP(sgd); +TEST(SGDOp, GetOpProto) { + auto& protos = paddle::framework::OpRegistry::protos(); + auto it = protos.find("sgd"); + ASSERT_NE(it, protos.end()); +} diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index 6ac4035c0f863c5f63d17b523a7a8be668ff3da0..bd77bb7daa50e0b273f110624ddf6f4b79a3ceab 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -8,6 +8,8 @@ cc_test(place_test SRCS place_test.cc DEPS place glog gflags) add_subdirectory(dynload) +cc_test(enforce_test SRCS enforce_test.cc) + IF(WITH_GPU) set(GPU_CTX_DEPS dynload_cuda dynamic_loader) ELSE() diff --git a/paddle/platform/cpu_info.cc b/paddle/platform/cpu_info.cc index dfab391cfbe1f04bc2a998233f7e7909579ca72b..78e1fa9df56b1623bfd9a53c6a37524d29648afc 100644 --- a/paddle/platform/cpu_info.cc +++ b/paddle/platform/cpu_info.cc @@ -22,7 +22,6 @@ limitations under the License. */ #endif #include "gflags/gflags.h" -#include "paddle/platform/error.h" DEFINE_double(fraction_of_cpu_memory_to_use, 1, "Default use 100% of CPU memory for PaddlePaddle," diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index f226a75c20b7a75e5f884cd158d139ebb8b34e47..fe6f13e399a78f9e5230ae52b0f67ab465af373b 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -11,12 +11,13 @@ limitations under the License. */ #pragma once -#include "paddle/framework/enforce.h" +#include "paddle/platform/enforce.h" +#include "paddle/platform/place.h" + #ifndef PADDLE_ONLY_CPU #include "paddle/platform/dynload/cublas.h" #include "paddle/platform/dynload/cudnn.h" #include "paddle/platform/dynload/curand.h" -#include "paddle/platform/error.h" #include "paddle/platform/gpu_info.h" #define EIGEN_USE_GPU #endif @@ -71,8 +72,7 @@ class CUDADeviceContext : public DeviceContext { public: explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) { GPUPlaceGuard guard(gpu_place_); - paddle::platform::throw_on_error(cudaStreamCreate(&stream_), - "cudaStreamCreate failed"); + PADDLE_ENFORCE(cudaStreamCreate(&stream_), "cudaStreamCreate failed"); eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_)); eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get())); } @@ -83,8 +83,8 @@ class CUDADeviceContext : public DeviceContext { } void Wait() { - paddle::platform::throw_on_error(cudaStreamSynchronize(stream_), - "cudaStreamSynchronize failed"); + PADDLE_ENFORCE(cudaStreamSynchronize(stream_), + "cudaStreamSynchronize failed"); } cudaStream_t stream() { return stream_; } @@ -94,12 +94,11 @@ class CUDADeviceContext : public DeviceContext { cublasHandle_t cublas_handle() { if (!blas_handle_) { GPUPlaceGuard guard(gpu_place_); - PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_) == - CUBLAS_STATUS_SUCCESS, + PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_), "cublasCreate failed"); - PADDLE_ENFORCE(paddle::platform::dynload::cublasSetStream( - blas_handle_, stream_) == CUBLAS_STATUS_SUCCESS, - "cublasSetStream failed"); + PADDLE_ENFORCE( + paddle::platform::dynload::cublasSetStream(blas_handle_, stream_), + "cublasSetStream failed"); } return blas_handle_; } @@ -107,12 +106,11 @@ class CUDADeviceContext : public DeviceContext { cudnnHandle_t cudnn_handle() { if (!dnn_handle_) { GPUPlaceGuard guard(gpu_place_); - PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_) == - CUDNN_STATUS_SUCCESS, + PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_), "cudnnCreate failed"); - PADDLE_ENFORCE(paddle::platform::dynload::cudnnSetStream( - dnn_handle_, stream_) == CUDNN_STATUS_SUCCESS, - "cudnnSetStream failed"); + PADDLE_ENFORCE( + paddle::platform::dynload::cudnnSetStream(dnn_handle_, stream_), + "cudnnSetStream failed"); } return dnn_handle_; } @@ -121,16 +119,15 @@ class CUDADeviceContext : public DeviceContext { if (!rand_generator_) { GPUPlaceGuard guard(gpu_place_); PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator( - &rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) == - CURAND_STATUS_SUCCESS, + &rand_generator_, CURAND_RNG_PSEUDO_DEFAULT), "curandCreateGenerator failed"); PADDLE_ENFORCE( paddle::platform::dynload::curandSetPseudoRandomGeneratorSeed( - rand_generator_, random_seed_) == CURAND_STATUS_SUCCESS, + rand_generator_, random_seed_), "curandSetPseudoRandomGeneratorSeed failed"); - PADDLE_ENFORCE(paddle::platform::dynload::curandSetStream( - rand_generator_, stream_) == CURAND_STATUS_SUCCESS, - "curandSetStream failed"); + PADDLE_ENFORCE( + paddle::platform::dynload::curandSetStream(rand_generator_, stream_), + "curandSetStream failed"); } return rand_generator_; } @@ -138,26 +135,23 @@ class CUDADeviceContext : public DeviceContext { ~CUDADeviceContext() { Wait(); if (blas_handle_) { - PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_) == - CUBLAS_STATUS_SUCCESS, + PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_), "cublasDestroy failed"); } if (dnn_handle_) { - PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_) == - CUDNN_STATUS_SUCCESS, + PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_), "cudnnDestroy failed"); } if (rand_generator_) { - PADDLE_ENFORCE(paddle::platform::dynload::curandDestroyGenerator( - rand_generator_) == CURAND_STATUS_SUCCESS, - "curandDestroyGenerator failed"); + PADDLE_ENFORCE( + paddle::platform::dynload::curandDestroyGenerator(rand_generator_), + "curandDestroyGenerator failed"); } eigen_stream_.reset(); eigen_device_.reset(); - paddle::platform::throw_on_error(cudaStreamDestroy(stream_), - "cudaStreamDestroy failed"); + PADDLE_ENFORCE(cudaStreamDestroy(stream_), "cudaStreamDestroy failed"); } private: diff --git a/paddle/platform/dynload/dynamic_loader.cc b/paddle/platform/dynload/dynamic_loader.cc index dd914e006d54c423ffea56ffaaafe7dcba416361..ae9a0a982c73de05821579d22b7f9ad99f24a92b 100644 --- a/paddle/platform/dynload/dynamic_loader.cc +++ b/paddle/platform/dynload/dynamic_loader.cc @@ -19,7 +19,7 @@ limitations under the License. */ #include #include "gflags/gflags.h" #include "glog/logging.h" -#include "paddle/framework/enforce.h" +#include "paddle/platform/enforce.h" DEFINE_string(cudnn_dir, "", "Specify path for loading libcudnn.so. For instance, " diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h new file mode 100644 index 0000000000000000000000000000000000000000..5d440dec48e7a4cba404bc297eca5a451a144d93 --- /dev/null +++ b/paddle/platform/enforce.h @@ -0,0 +1,141 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include + +#ifndef PADDLE_ONLY_CPU + +#include "paddle/platform/dynload/cublas.h" +#include "paddle/platform/dynload/cudnn.h" +#include "paddle/platform/dynload/curand.h" + +#include +#include +#include +#include +#include + +#endif // PADDLE_ONLY_CPU + +namespace paddle { +namespace platform { + +// Because most enforce conditions would evaluate to true, we can use +// __builtin_expect to instruct the C++ compiler to generate code that +// always forces branch prediction of true. +// This generates faster binary code. __builtin_expect is since C++11. +// For more details, please check https://stackoverflow.com/a/43870188/724872. +#define UNLIKELY(condition) __builtin_expect(static_cast(condition), 0) + +#ifndef PADDLE_ONLY_CPU + +template +inline void throw_on_error(cudaError_t e, const Args&... args) { + if (UNLIKELY(e)) { + // clang-format off + throw thrust::system_error( + e, thrust::cuda_category(), + string::Sprintf(args...) + + string::Sprintf(" at [%s:%s];", __FILE__, __LINE__)); + // clang-format on + } +} + +template +inline void throw_on_error(curandStatus_t stat, const Args&... args) { + if (stat != CURAND_STATUS_SUCCESS) { + // clang-format off + throw thrust::system_error( + cudaErrorLaunchFailure, thrust::cuda_category(), + string::Sprintf(args...) + + string::Sprintf(" at [%s:%s];", __FILE__, __LINE__)); + // clang-format on + } +} + +template +inline void throw_on_error(cudnnStatus_t stat, const Args&... args) { + if (stat == CUDNN_STATUS_SUCCESS) { + return; + } else { + // clang-format off + throw std::runtime_error( + platform::dynload::cudnnGetErrorString(stat) + + string::Sprintf(args...) + + string::Sprintf(" at [%s:%s];", __FILE__, __LINE__)); + // clang-format on + } +} + +template +inline void throw_on_error(cublasStatus_t stat, const Args&... args) { + std::string err; + if (stat == CUBLAS_STATUS_SUCCESS) { + return; + } else if (stat == CUBLAS_STATUS_NOT_INITIALIZED) { + err = "CUBLAS: not initialized, "; + } else if (stat == CUBLAS_STATUS_ALLOC_FAILED) { + err = "CUBLAS: alloc failed, "; + } else if (stat == CUBLAS_STATUS_INVALID_VALUE) { + err = "CUBLAS: invalid value, "; + } else if (stat == CUBLAS_STATUS_ARCH_MISMATCH) { + err = "CUBLAS: arch mismatch, "; + } else if (stat == CUBLAS_STATUS_MAPPING_ERROR) { + err = "CUBLAS: mapping error, "; + } else if (stat == CUBLAS_STATUS_EXECUTION_FAILED) { + err = "CUBLAS: execution failed, "; + } else if (stat == CUBLAS_STATUS_INTERNAL_ERROR) { + err = "CUBLAS: internal error, "; + } else if (stat == CUBLAS_STATUS_NOT_SUPPORTED) { + err = "CUBLAS: not supported, "; + } else if (stat == CUBLAS_STATUS_LICENSE_ERROR) { + err = "CUBLAS: license error, "; + } + throw std::runtime_error(err + string::Sprintf(args...) + + string::Sprintf(" at [%s:%s];", __FILE__, __LINE__)); +} + +#endif // PADDLE_ONLY_CPU + +template +inline void throw_on_error(int stat, const Args&... args) { + if (UNLIKELY(!(stat))) { + throw std::runtime_error( + string::Sprintf(args...) + + string::Sprintf(" at [%s:%s];", __FILE__, __LINE__)); + } +} + +#define PADDLE_THROW(...) \ + do { \ + throw std::runtime_error( \ + string::Sprintf(__VA_ARGS__) + \ + string::Sprintf(" at [%s:%s];", __FILE__, __LINE__)); \ + } while (0) + +/** + * @brief Enforce a condition, otherwise throw an EnforceNotMet + */ +#define PADDLE_ENFORCE(condition, ...) \ + do { \ + ::paddle::platform::throw_on_error(condition, __VA_ARGS__); \ + } while (0) + +} // namespace platform +} // namespace paddle diff --git a/paddle/framework/enforce_test.cc b/paddle/platform/enforce_test.cc similarity index 85% rename from paddle/framework/enforce_test.cc rename to paddle/platform/enforce_test.cc index f8da1a192f63a54324d80725c9d2f156fb11a481..d7152f81509a35e4ce36d5649e7d209f51e34b86 100644 --- a/paddle/framework/enforce_test.cc +++ b/paddle/platform/enforce_test.cc @@ -9,8 +9,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include -#include +#include "paddle/platform/enforce.h" +#include "gtest/gtest.h" TEST(ENFORCE, OK) { PADDLE_ENFORCE(true, "Enforce is ok %d now %f", 123, 0.345); @@ -23,13 +23,14 @@ TEST(ENFORCE, FAILED) { bool in_catch = false; try { PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123); - } catch (paddle::framework::EnforceNotMet err) { + } catch (const std::runtime_error& error) { + // your error handling code here in_catch = true; std::string msg = "Enforce is not ok 123 at all"; - const char* what = err.what(); + const char* what = error.what(); for (size_t i = 0; i < msg.length(); ++i) { ASSERT_EQ(what[i], msg[i]); } } ASSERT_TRUE(in_catch); -} \ No newline at end of file +} diff --git a/paddle/platform/error.h b/paddle/platform/error.h deleted file mode 100644 index 93424bb61096503a4843da7942853a113f612e3b..0000000000000000000000000000000000000000 --- a/paddle/platform/error.h +++ /dev/null @@ -1,87 +0,0 @@ -#pragma once - -#include -#include -#include - -#ifndef PADDLE_ONLY_CPU - -#include -#include -#include -#include -#include - -#endif // PADDLE_ONLY_CPU - -namespace paddle { -namespace platform { - -#ifndef PADDLE_ONLY_CPU - -inline void throw_on_error(cudaError_t e, const char* message) { - if (e) { - throw thrust::system_error(e, thrust::cuda_category(), message); - } -} - -inline void throw_on_error(curandStatus_t stat, const char* message) { - if (stat != CURAND_STATUS_SUCCESS) { - throw thrust::system_error(cudaErrorLaunchFailure, thrust::cuda_category(), - message); - } -} - -inline void throw_on_error(cudnnStatus_t stat, const char* message) { - std::stringstream ss; - if (stat == CUDNN_STATUS_SUCCESS) { - return; - } else { - ss << cudnnGetErrorString(stat); - ss << ", " << message; - throw std::runtime_error(ss.str()); - } -} - -inline void throw_on_error(cublasStatus_t stat, const char* message) { - std::stringstream ss; - if (stat == CUBLAS_STATUS_SUCCESS) { - return; - } else if (stat == CUBLAS_STATUS_NOT_INITIALIZED) { - ss << "CUBLAS: not initialized"; - } else if (stat == CUBLAS_STATUS_ALLOC_FAILED) { - ss << "CUBLAS: alloc failed"; - } else if (stat == CUBLAS_STATUS_INVALID_VALUE) { - ss << "CUBLAS: invalid value"; - } else if (stat == CUBLAS_STATUS_ARCH_MISMATCH) { - ss << "CUBLAS: arch mismatch"; - } else if (stat == CUBLAS_STATUS_MAPPING_ERROR) { - ss << "CUBLAS: mapping error"; - } else if (stat == CUBLAS_STATUS_EXECUTION_FAILED) { - ss << "CUBLAS: execution failed"; - } else if (stat == CUBLAS_STATUS_INTERNAL_ERROR) { - ss << "CUBLAS: internal error"; - } else if (stat == CUBLAS_STATUS_NOT_SUPPORTED) { - ss << "CUBLAS: not supported"; - } else if (stat == CUBLAS_STATUS_LICENSE_ERROR) { - ss << "CUBLAS: license error"; - } - ss << ", " << message; - throw std::runtime_error(ss.str()); -} - -inline void throw_on_error(cublasStatus_t stat) { - const char* message = ""; - throw_on_error(stat, message); -} - -#endif // PADDLE_ONLY_CPU - -inline void throw_on_error(int stat, const char* message) { - if (stat) { - throw std::runtime_error(message + (", stat = " + std::to_string(stat))); - } -} - -} // namespace platform -} // namespace paddle diff --git a/paddle/platform/gpu_info.cc b/paddle/platform/gpu_info.cc index a1383d3524aedf834c329425419b989d47668bea..cf9921e870d47fe77c0cca80828dbf2bb36ccda8 100644 --- a/paddle/platform/gpu_info.cc +++ b/paddle/platform/gpu_info.cc @@ -14,7 +14,7 @@ limitations under the License. */ #include "paddle/platform/gpu_info.h" #include "gflags/gflags.h" -#include "paddle/platform/error.h" +#include "paddle/platform/enforce.h" DEFINE_double(fraction_of_gpu_memory_to_use, 0.95, "Default use 95% of GPU memory for PaddlePaddle," @@ -25,7 +25,7 @@ namespace platform { int GetDeviceCount() { int count; - throw_on_error( + PADDLE_ENFORCE( cudaGetDeviceCount(&count), "cudaGetDeviceCount failed in paddle::platform::GetDeviceCount"); return count; @@ -33,19 +33,19 @@ int GetDeviceCount() { int GetCurrentDeviceId() { int device_id; - throw_on_error( + PADDLE_ENFORCE( cudaGetDevice(&device_id), "cudaGetDevice failed in paddle::platform::GetCurrentDeviceId"); return device_id; } void SetDeviceId(int id) { - throw_on_error(cudaSetDevice(id), + PADDLE_ENFORCE(cudaSetDevice(id), "cudaSetDevice failed in paddle::platform::SetDeviceId"); } void GpuMemoryUsage(size_t& available, size_t& total) { - throw_on_error(cudaMemGetInfo(&available, &total), + PADDLE_ENFORCE(cudaMemGetInfo(&available, &total), "cudaMemGetInfo failed in paddle::platform::GetMemoryUsage"); } diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt index 00b14a94321990baef6de35df547eed04b3da04f..6354dd211d5d036e1b5971babaf624e8f847a92b 100644 --- a/paddle/pybind/CMakeLists.txt +++ b/paddle/pybind/CMakeLists.txt @@ -1,2 +1,2 @@ cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python - add_op mul_op rowwise_add_op sigmoid_op softmax_op) + add_op fc_op sgd_op) diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index fc9c6544c3cbf5a804b2d052f738bd483d6bf41b..54707a2859693af4a80692bf5cebab59c43ffbc3 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include #include +#include #include #include #include @@ -26,10 +27,8 @@ namespace py = pybind11; namespace pd = paddle::framework; USE_OP(add_two); -USE_OP(softmax); -USE_OP(mul); -USE_OP(rowwise_add); -USE_OP(sigmoid); +USE_OP_WITHOUT_KERNEL(fc); +USE_OP(sgd); PYBIND11_PLUGIN(core) { py::module m("core", "C++ core of Paddle Paddle"); @@ -53,7 +52,9 @@ PYBIND11_PLUGIN(core) { self.mutable_data(paddle::platform::CPUPlace()); }) .def("set", paddle::pybind::PyTensorSetFromArray) - .def("set", paddle::pybind::PyTensorSetFromArray); + .def("set", paddle::pybind::PyTensorSetFromArray) + .def("shape", + [](pd::Tensor& self) { return pd::vectorize(self.dims()); }); py::class_(m, "Variable", R"DOC(Variable Class. @@ -83,15 +84,16 @@ All parameter, weight, gradient are variables in Paddle. //! @note: Be careful! PyBind will return std::string as an unicode, not //! Python str. If you want a str object, you should cast them in Python. - m.def("get_all_op_protos", []() -> std::vector { + m.def("get_all_op_protos", []() -> std::vector { auto& protos = pd::OpRegistry::protos(); - std::vector ret_values; + std::vector ret_values; for (auto it = protos.begin(); it != protos.end(); ++it) { PADDLE_ENFORCE(it->second.IsInitialized(), "OpProto must all be initialized"); - ret_values.emplace_back(); - PADDLE_ENFORCE(it->second.SerializeToString(&ret_values.back()), + std::string str; + PADDLE_ENFORCE(it->second.SerializeToString(&str), "Serialize OpProto Error. This could be a bug of Paddle."); + ret_values.push_back(py::bytes(str)); } return ret_values; }); @@ -101,17 +103,26 @@ All parameter, weight, gradient are variables in Paddle. .def("empty", pd::OperatorBase::EMPTY_VAR_NAME) .def("temp", pd::OperatorBase::TMP_VAR_NAME); + py::class_(m, "DeviceContext") + .def_static("cpu_context", []() -> paddle::platform::DeviceContext* { + return new paddle::platform::CPUDeviceContext(); + }); + py::class_(m, "Operator") .def("__str__", &pd::OperatorBase::DebugString) - .def_static("create", [](const std::string& protobin) { - pd::OpDesc desc; - PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), - "Cannot parse user input to OpDesc"); - PADDLE_ENFORCE(desc.IsInitialized(), - "User OpDesc is not initialized, reason %s", - desc.InitializationErrorString()); - return pd::OpRegistry::CreateOp(desc); - }); + .def_static("create", + [](py::bytes protobin) { + pd::OpDesc desc; + PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), + "Cannot parse user input to OpDesc"); + PADDLE_ENFORCE(desc.IsInitialized(), + "User OpDesc is not initialized, reason %s", + desc.InitializationErrorString()); + return pd::OpRegistry::CreateOp(desc); + }) + .def("infer_shape", &pd::OperatorBase::InferShape) + .def("run", &pd::OperatorBase::Run) + .def("outputs", [](const pd::OperatorPtr& op) { return op->outputs_; }); return m.ptr(); } diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index 37cd16c79890738f6d8966579e15686c653d4df3..83f72c137bdf5e55f28be908321bd2ccd6c906fe 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -472,10 +472,16 @@ message LayerConfig { // blank label used in ctc loss optional uint32 blank = 52 [default = 0]; - // stride parameter for seqlastins layer, AverageLayer, MaxLayer, which + // stride parameter for seqlastins layer, AverageLayer, MaxLayer, which // controls the scope of pooling operation. can be set > 0. // leave empty or set to -1 to disable this stride pooling. optional int32 seq_pool_stride = 53 [default = -1]; + + // for crop layer + optional int32 axis = 54 [default = 2]; + repeated uint32 offset = 55; + repeated uint32 shape = 56; + } message EvaluatorConfig { diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 7190f0e8c9cfe7b96ba6d669a2058b4e3f139f68..fc112f1327f5ad5f1bdd04873394b1fa0e761e29 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -1998,6 +1998,23 @@ class PadLayer(LayerBase): self.config.size = out_ch * out_h * out_w +@config_layer('crop') +class CropLayer(LayerBase): + def __init__(self, name, inputs, axis, offset, shape, **xargs): + super(CropLayer, self).__init__(name, 'crop', 0, inputs=inputs, **xargs) + self.config.axis = axis + self.config.offset.extend(offset) + self.config.shape.extend(shape) + + # get channel, width and height from input_0 layer + input_layer = self.get_input_layer(0) + image_conf = self.config.inputs[0].image_conf + image_conf.img_size = input_layer.width + image_conf.img_size_y = input_layer.height + image_conf.channels = input_layer.size / (input_layer.width * + input_layer.height) + + @config_layer('batch_norm') class BatchNormLayer(LayerBase): layer_type = 'batch_norm' diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 78aa0778f8d1dca9fae82f0411be5a00e636cbc9..fdb6f83f2ba510232714fb8a9c7c1af837a753ff 100755 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -127,6 +127,7 @@ __all__ = [ 'dropout_layer', 'prelu_layer', 'gated_unit_layer', + 'crop_layer', ] @@ -218,6 +219,7 @@ class LayerType(object): SMOOTH_L1 = 'smooth_l1' PRELU = 'prelu' + CROP_LAYER = 'crop' @staticmethod def is_layer_type(type_name): @@ -5970,3 +5972,52 @@ def gated_unit_layer(input, name="%s_gated_act" % name, input=dotmul_operator(input_proj, gate), layer_attr=layer_attr) + + +@wrap_name_default() +@layer_support() +def crop_layer(input, offset, axis=2, shape=None, name=None, layer_attr=None): + """ + The crop layer crops images by offset and shape. User can set crop shape by + args 'shape' explicitly or by reference input layer. + + The example usage is: + + .. code-block:: python + crop = crop_layer(input=[image_input, reference_input], axis=2, offset=[2, 3]) + + :param input: The input layer.If two inputs were setted, + the second input will be regarded as reference input + :type input: LayerOutput or Sequence + :param offset: The crop offset + :type offset: Sequence + :param axis: start axis to be cropped. To image input layer: + - 0: batch size + - 1: channels + - 2: height + - 3: width + :type partial_sum: int + :param shape: The shape to be cropped. Default is None. + :type shape: Sequence | None + :param name: Name of this layer. + :type name: basestring + :return: LayerOutput object. + :rtype: LayerOutput + """ + if isinstance(input, LayerOutput): + input = [input] + else: + assert isinstance(input, collections.Sequence) + l = Layer( + inputs=[x.name for x in input], + axis=axis, + offset=offset, + shape=shape, + name=name, + type=LayerType.CROP_LAYER, + **ExtraLayerAttribute.to_kwargs(layer_attr)) + return LayerOutput( + name=name, + layer_type=LayerType.CROP_LAYER, + parents=input, + size=l.config.size) diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_crop.py b/python/paddle/trainer_config_helpers/tests/configs/test_crop.py new file mode 100644 index 0000000000000000000000000000000000000000..8314a7e9a5586647c70ff010156817110919c72b --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/test_crop.py @@ -0,0 +1,21 @@ +from paddle.trainer_config_helpers import * + +settings(batch_size=1000, learning_rate=1e-5) + +data = data_layer(name='data', size=2016, height=48, width=42) +refernce_data = data_layer(name='data', size=768, height=16, width=16) + +conv = img_conv_layer( + input=data, + filter_size=3, + num_channels=1, + num_filters=16, + padding=1, + act=LinearActivation(), + bias_attr=True) + +pool = img_pool_layer(input=conv, pool_size=2, stride=2, pool_type=MaxPooling()) + +crop = crop_layer(input=[pool, refernce_data], axis=2) + +outputs(pad) diff --git a/python/paddle/v2/framework/create_op_creation_methods.py b/python/paddle/v2/framework/create_op_creation_methods.py index c2a7ae7692b08762ffbc91726be7bfa90e8ddedb..7248c3f52a9902e8c08ac2f1405801a5710459e5 100644 --- a/python/paddle/v2/framework/create_op_creation_methods.py +++ b/python/paddle/v2/framework/create_op_creation_methods.py @@ -217,6 +217,10 @@ def create_op_creation_method(op_proto): return core.Operator.create(opdesc.SerializeToString()) __impl__.__doc__ = get_docstring_from_op_proto(op_proto) + __impl__.all_input_args = [var.name for var in op_proto.inputs] + __impl__.all_output_args = [var.name for var in op_proto.outputs] + __impl__.all_attr_args = [attr.name for attr in op_proto.attrs] + return __impl__ diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index 4ce2bef6fcc4b8ddf5a6de3809a1891bce590aab..ec076e40c9312fee7f3ba030dc69208069fd45a8 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -1,3 +1,3 @@ add_python_test(test_framework test_protobuf.py test_scope.py test_default_scope_funcs.py test_op_creation_methods.py - test_tensor.py) + test_tensor.py test_fc_op.py test_add_two_op.py test_sgd_op.py) diff --git a/python/paddle/v2/framework/tests/op_test_util.py b/python/paddle/v2/framework/tests/op_test_util.py new file mode 100644 index 0000000000000000000000000000000000000000..b1fa12cc89fa724994ea482ab0a3d78c03a9cdf0 --- /dev/null +++ b/python/paddle/v2/framework/tests/op_test_util.py @@ -0,0 +1,62 @@ +import paddle.v2.framework.core as core +import unittest +import numpy +import paddle.v2.framework.create_op_creation_methods as creation + + +class OpTestMeta(type): + """ + Operator Test ClassMeta. + + It injects `test_all` method into user's OperatorTest class, to make Python + unittest module run that method. + + The `test_all` read what value is stored in `self`. It use self's values to + create and run a operator, and check whether that op is OK or not. + + See `test_add_two_op` for example usage. + """ + + def __new__(cls, name, bases, attrs): + obj = super(OpTestMeta, cls).__new__(cls, name, bases, attrs) + + def test_all(self): + func = getattr(creation.op_creations, self.type, None) + self.assertIsNotNone(func) + + scope = core.Scope(None) + kwargs = dict() + + for in_name in func.all_input_args: + if hasattr(self, in_name): + kwargs[in_name] = in_name + var = scope.create_var(in_name).get_tensor() + arr = getattr(self, in_name) + var.set_dims(arr.shape) + var.set(arr) + else: + kwargs[in_name] = "@EMPTY@" + + for out_name in func.all_output_args: + if hasattr(self, out_name): + kwargs[out_name] = out_name + scope.create_var(out_name).get_tensor() + + for attr_name in func.all_attr_args: + if hasattr(self, attr_name): + kwargs[attr_name] = getattr(self, attr_name) + + op = func(**kwargs) + + op.infer_shape(scope) + + ctx = core.DeviceContext.cpu_context() + op.run(scope, ctx) + + for out_name in func.all_output_args: + actual = numpy.array(scope.get_var(out_name).get_tensor()) + expect = getattr(self, out_name) + numpy.testing.assert_almost_equal(actual, expect) + + obj.test_all = test_all + return obj diff --git a/python/paddle/v2/framework/tests/test_add_two_op.py b/python/paddle/v2/framework/tests/test_add_two_op.py new file mode 100644 index 0000000000000000000000000000000000000000..a06d7a78ecf838a49e5f2808d3686c6b92faa8ce --- /dev/null +++ b/python/paddle/v2/framework/tests/test_add_two_op.py @@ -0,0 +1,17 @@ +import unittest +from op_test_util import OpTestMeta +import numpy + + +class TestAddOp(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = "add_two" + self.X = numpy.random.random((342, 345)).astype("float32") + self.Y = numpy.random.random((342, 345)).astype("float32") + self.Out = self.X + self.Y + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_fc_op.py b/python/paddle/v2/framework/tests/test_fc_op.py new file mode 100644 index 0000000000000000000000000000000000000000..59e7e61249e2a7d49a17e5d87209f03b8f35f730 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_fc_op.py @@ -0,0 +1,43 @@ +import paddle.v2.framework.core as core +import unittest +import numpy +import paddle.v2.framework.create_op_creation_methods as creation + + +class TestFc(unittest.TestCase): + def test_fc(self): + scope = core.Scope(None) + x = scope.create_var("X") + x_tensor = x.get_tensor() + x_tensor.set_dims([1000, 784]) + x_tensor.alloc_float() + + w = scope.create_var("W") + w_tensor = w.get_tensor() + w_tensor.set_dims([784, 100]) + w_tensor.alloc_float() + + w_tensor.set(numpy.random.random((784, 100)).astype("float32")) + + # Set a real numpy array here. + # x_tensor.set(numpy.array([])) + + op = creation.op_creations.fc(X="X", Y="Y", W="W") + + for out in op.outputs(): + if scope.get_var(out) is None: + scope.create_var(out).get_tensor() + + tensor = scope.get_var("Y").get_tensor() + op.infer_shape(scope) + self.assertEqual([1000, 100], tensor.shape()) + + ctx = core.DeviceContext.cpu_context() + + op.run(scope, ctx) + + # After complete all ops, check Y is expect or not. + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_sgd_op.py b/python/paddle/v2/framework/tests/test_sgd_op.py new file mode 100644 index 0000000000000000000000000000000000000000..405d73b224fa153e50b4ec408a921f2bdaab46aa --- /dev/null +++ b/python/paddle/v2/framework/tests/test_sgd_op.py @@ -0,0 +1,18 @@ +import unittest +import numpy +from op_test_util import OpTestMeta + + +class TestSGD(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = "sgd" + self.param = numpy.random.random((342, 345)).astype("float32") + self.grad = numpy.random.random((342, 345)).astype("float32") + self.learning_rate = 0.1 + self.param_out = self.param - self.learning_rate * self.grad + + +if __name__ == "__main__": + unittest.main()