diff --git a/cmake/generic.cmake b/cmake/generic.cmake
index e42e75c12ab1e5133f5ecbdb90ef26e3f8df5133..534be0abe246ac70950d85ad05441825c8ca768a 100644
--- a/cmake/generic.cmake
+++ b/cmake/generic.cmake
@@ -290,8 +290,22 @@ function(go_library TARGET_NAME)
     set(${TARGET_NAME}_LIB_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}${TARGET_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE STRING "output library name for target ${TARGET_NAME}")
   endif()
 
-  # Add dummy code to support `make target_name` under Terminal Command
   set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}_dummy.c)
+
+  # This custom command always runs because it depends on a file that
+  # never exists.
+  add_custom_command(
+    OUTPUT dummy_rebuild_${TARGET_NAME}
+    COMMAND cmake -E touch ${dummyfile}
+  )
+  # Create a custom target that depends on the custom command's output
+  # file, so that the custom command can be referenced as a dependency
+  # by `add_dependencies`.
+  add_custom_target(rebuild_${TARGET_NAME}
+    DEPENDS dummy_rebuild_${TARGET_NAME}
+  )
+
+  # Add dummy code to support `make target_name` under Terminal Command
   file(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";")
   if (go_library_SHARED OR go_library_shared)
     add_library(${TARGET_NAME} SHARED ${dummyfile})
@@ -302,6 +316,12 @@ function(go_library TARGET_NAME)
     add_dependencies(${TARGET_NAME} ${go_library_DEPS})
   endif(go_library_DEPS)
 
+  # The library's "source file" is `${dummyfile}`, which never changes,
+  # so on its own the target would never rebuild. Make the target depend
+  # on the custom command that touches the "source file", so that a
+  # rebuild always happens.
+  add_dependencies(${TARGET_NAME} rebuild_${TARGET_NAME})
+
   set(${TARGET_NAME}_LIB_PATH "${CMAKE_CURRENT_BINARY_DIR}/${${TARGET_NAME}_LIB_NAME}" CACHE STRING "output library path for target ${TARGET_NAME}")
 
   file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go")
diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt
index a00b9c81906fb1194b51efc50e6255f092875281..760d84e51e7473d359a415e4790251db3d139ab2 100644
--- a/paddle/framework/CMakeLists.txt
+++ b/paddle/framework/CMakeLists.txt
@@ -1,26 +1,25 @@
 # ddim lib
-cc_library(enforce SRCS enforce.cc DEPS glog)
-cc_test(enforce_test SRCS enforce_test.cc DEPS enforce)
 cc_library(ddim SRCS ddim.cc DEPS eigen3)
 cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
 nv_test(dim_test SRCS dim_test.cu DEPS ddim)
-cc_library(tensor SRCS tensor.cc DEPS ddim place enforce paddle_memory)
+cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory)
 cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
 cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
 cc_test(variable_test SRCS variable_test.cc)
 cc_test(scope_test SRCS scope_test.cc)
+
 proto_library(attr_type SRCS attr_type.proto)
 proto_library(op_proto SRCS op_proto.proto DEPS attr_type)
-cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
 proto_library(op_desc SRCS op_desc.proto DEPS attr_type)
+cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
 cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
 cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor)
 cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
-cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc enforce)
+cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc)
 cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator)
 py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
diff --git a/paddle/framework/attr_checker.h b/paddle/framework/attr_checker.h
index f2d88f3cb00e20f548a5cd412b515e843491a76d..ea5614a45f3a77a851358aff80abbc276c9972ba 100644
--- a/paddle/framework/attr_checker.h
+++ b/paddle/framework/attr_checker.h
@@ -6,7 +6,7 @@
 #include
 #include
 #include
-#include "paddle/framework/enforce.h"
+#include "paddle/platform/enforce.h"
 
 namespace paddle {
 namespace framework {
diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc
index d2ef85afe55e640a17b8c957bac61d175e69ff3f..545c1dcc2a1682839d90194002fdbb748d85e808 100644
--- a/paddle/framework/ddim.cc
+++ b/paddle/framework/ddim.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/framework/ddim.h"
-#include "paddle/framework/enforce.h"
+#include "paddle/platform/enforce.h"
 
 namespace paddle {
 namespace framework {
diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h
index 06c4c583b3afa9b472561e5d8166cef3398c57f4..9fcc657edcd5459d0a42a64d708603a4bcd53cf0 100644
--- a/paddle/framework/ddim.h
+++ b/paddle/framework/ddim.h
@@ -19,7 +19,7 @@ limitations under the License. */
 #include
 #include
 #include "paddle/framework/dim.h"
-#include "paddle/framework/enforce.h"
+#include "paddle/platform/enforce.h"
 #include "unsupported/Eigen/CXX11/Tensor"
 
 namespace paddle {
diff --git a/paddle/framework/enforce.cc b/paddle/framework/enforce.cc
deleted file mode 100644
index 644930ff989bb8935f37642c117084f580379bd7..0000000000000000000000000000000000000000
--- a/paddle/framework/enforce.cc
+++ /dev/null
@@ -1,15 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
-
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
-
-   http://www.apache.org/licenses/LICENSE-2.0
-
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License. */
-
-#include "paddle/framework/enforce.h"
diff --git a/paddle/framework/enforce.h b/paddle/framework/enforce.h
deleted file mode 100644
index ffce8148e9516a5720757c87685ff6bd2937977c..0000000000000000000000000000000000000000
--- a/paddle/framework/enforce.h
+++ /dev/null
@@ -1,75 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-#include
-#include
-#include
-#include
-
-namespace paddle {
-namespace framework {
-
-/**
- * @brief Enforce exception. Inherits std::exception
- *
- * All enforce condition not met, will throw an EnforceNotMet exception.
- */
-class EnforceNotMet : public std::exception {
- public:
-  EnforceNotMet(const std::string& msg, const char* file, int fileline) {
-    std::ostringstream sout;
-    sout << msg << " at [" << file << ":" << fileline << "];";
-    all_msg_ = sout.str();
-  }
-
-  const char* what() const noexcept override { return all_msg_.c_str(); }
-
- private:
-  std::string all_msg_;
-};
-
-// From https://stackoverflow.com/questions/30130930/
-// __buildin_expect is in C++ 11 standard. Since the condition which enforced
-// should be true in most situation, it will make the compiler generate faster
-// code by adding `UNLIKELY` macro.
-#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
-
-/**
- * @brief Throw a EnforceNotMet exception, automatically filled __FILE__ &
- * __LINE__
- *
- * This macro take __VA_ARGS__, user can pass any type if that type can
- * serialize to std::ostream
- */
-#define PADDLE_THROW(...)                                            \
-  do {                                                               \
-    throw ::paddle::framework::EnforceNotMet(                        \
-        ::paddle::string::Sprintf(__VA_ARGS__), __FILE__, __LINE__); \
-  } while (0)
-
-/**
- * @brief Enforce a condition, otherwise throw an EnforceNotMet
- */
-#ifdef NDEBUG
-#define PADDLE_ENFORCE(condition, ...) \
-  do {                                 \
-    if (UNLIKELY(!(condition))) {      \
-      PADDLE_THROW(__VA_ARGS__);       \
-    }                                  \
-  } while (0)
-#else
-#define PADDLE_ENFORCE(condition, ...) \
-  CHECK(condition) << ::paddle::string::Sprintf(__VA_ARGS__);
-#endif
-
-}  // namespace framework
-}  // namespace paddle
diff --git a/paddle/framework/net_op_test.cc b/paddle/framework/net_op_test.cc
index f5e1c22400a73c3aa09839ef9654f87def99bc77..e814a7e43d7ae7af0974d1a7c8b072bde5ba0238 100644
--- a/paddle/framework/net_op_test.cc
+++ b/paddle/framework/net_op_test.cc
@@ -63,5 +63,5 @@ TEST(OpKernel, all) {
 
   ASSERT_EQ(2, infer_shape_cnt);
   ASSERT_EQ(2, run_cnt);
-  ASSERT_THROW(net->AddOp(op2), paddle::framework::EnforceNotMet);
+  ASSERT_THROW(net->AddOp(op2), std::runtime_error);
 }
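Editor's note on the test changes above and below: the old code caught the concrete `EnforceNotMet` type by value, while the new code catches `std::runtime_error` by reference. Catching the base class by (const) reference keeps working no matter which concrete exception type the enforce machinery throws, and it avoids copying the thrown object. A minimal self-contained illustration, using a hypothetical derived type (not a Paddle class):

```cpp
#include <iostream>
#include <stdexcept>

// Hypothetical derived exception, standing in for whatever concrete
// type an enforce-style macro might throw.
struct EnforceError : std::runtime_error {
  using std::runtime_error::runtime_error;
};

int main() {
  try {
    throw EnforceError("check failed at [file.cc:42];");
  } catch (const std::runtime_error& e) {
    // Catches EnforceError (and any other runtime_error subclass)
    // without copying or slicing the thrown object.
    std::cout << e.what() << "\n";
  }
}
```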
diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc
index d3a51a361aa56b26b87d79057f6700bd87264ca4..32a7e88a894fb61a460443b7d593a6cf44bc98c5 100644
--- a/paddle/framework/op_registry_test.cc
+++ b/paddle/framework/op_registry_test.cc
@@ -91,7 +91,7 @@ TEST(OpRegistry, IllegalAttr) {
   try {
     paddle::framework::OperatorPtr op __attribute__((unused)) =
         paddle::framework::OpRegistry::CreateOp(op_desc);
-  } catch (paddle::framework::EnforceNotMet err) {
+  } catch (std::runtime_error& err) {
     caught = true;
     std::string msg = "larger_than check fail";
     const char* err_msg = err.what();
@@ -138,7 +138,7 @@ TEST(OpRegistry, CustomChecker) {
   try {
     paddle::framework::OperatorPtr op __attribute__((unused)) =
         paddle::framework::OpRegistry::CreateOp(op_desc);
-  } catch (paddle::framework::EnforceNotMet err) {
+  } catch (std::runtime_error& err) {
     caught = true;
     std::string msg = "Attribute 'test_attr' is required!";
     const char* err_msg = err.what();
@@ -157,7 +157,7 @@ TEST(OpRegistry, CustomChecker) {
   try {
     paddle::framework::OperatorPtr op __attribute__((unused)) =
         paddle::framework::OpRegistry::CreateOp(op_desc);
-  } catch (paddle::framework::EnforceNotMet err) {
+  } catch (std::runtime_error& err) {
     caught = true;
     std::string msg = "'test_attr' must be even!";
     const char* err_msg = err.what();
@@ -196,7 +196,7 @@ TEST(ProtoMaker, DuplicatedAttr) {
   pd::OpProto op_proto;
   pd::OpAttrChecker op_checker;
   auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker);
-  ASSERT_THROW(proto_maker.Validate(), paddle::framework::EnforceNotMet);
+  ASSERT_THROW(proto_maker.Validate(), std::runtime_error);
 }
 
 class TestInOutProtoMaker : public pd::OpProtoAndCheckerMaker {
@@ -212,5 +212,5 @@ TEST(ProtoMaker, DuplicatedInOut) {
   pd::OpProto op_proto;
   pd::OpAttrChecker op_checker;
   auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
-  ASSERT_THROW(proto_maker.Validate(), paddle::framework::EnforceNotMet);
+  ASSERT_THROW(proto_maker.Validate(), std::runtime_error);
 }
diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h
index 39e0f9f7103dae4f710a80e8c33702094e1ab590..93c6fad5d3d9f3de100d30161e6e438eb43816a2 100644
--- a/paddle/framework/tensor.h
+++ b/paddle/framework/tensor.h
@@ -19,8 +19,8 @@ limitations under the License. */
 #include
 #include
 #include "paddle/framework/ddim.h"
-#include "paddle/framework/enforce.h"
 #include "paddle/memory/memory.h"
+#include "paddle/platform/enforce.h"
 #include "paddle/platform/place.h"
 #include "unsupported/Eigen/CXX11/Tensor"
diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc
index 84c6f0cf6558819440458688ca52b06c1cf11dd0..8a7cbbd0de6fd6aaafa8649abb8628e971bc49c1 100644
--- a/paddle/framework/tensor_test.cc
+++ b/paddle/framework/tensor_test.cc
@@ -33,7 +33,7 @@ TEST(Tensor, DataAssert) {
   bool caught = false;
   try {
     src_tensor.data();
-  } catch (paddle::framework::EnforceNotMet err) {
+  } catch (std::runtime_error& err) {
     caught = true;
     std::string msg =
         "Tenosr holds no memory. Call Tensor::mutable_data first.";
@@ -107,7 +107,7 @@ TEST(Tensor, ShareDataFrom) {
     bool caught = false;
     try {
       dst_tensor.ShareDataFrom(src_tensor);
-    } catch (EnforceNotMet err) {
+    } catch (std::runtime_error& err) {
       caught = true;
       std::string msg =
           "Tenosr holds no memory. Call Tensor::mutable_data first.";
diff --git a/paddle/memory/detail/system_allocator.cc b/paddle/memory/detail/system_allocator.cc
index 1579174b1a6ff08824629d833d01411cff651f48..f61e67a32906083881dd7f47433521876be9b355 100644
--- a/paddle/memory/detail/system_allocator.cc
+++ b/paddle/memory/detail/system_allocator.cc
@@ -14,7 +14,7 @@ limitations under the License. */
 
 #include "paddle/memory/detail/system_allocator.h"
 #include "paddle/platform/assert.h"
-#include "paddle/platform/error.h"
+#include "paddle/platform/enforce.h"
 #include "paddle/platform/gpu_info.h"
 
 #include <stdlib.h>  // for malloc and free
@@ -128,8 +128,7 @@ void GPUAllocator::Free(void* p, size_t size, size_t index) {
   // process is terminating, in which case we don't care if
   // cudaFree succeeds.
   if (err != cudaErrorCudartUnloading) {
-    platform::throw_on_error(err,
-                             "cudaFree{Host} failed in GPUAllocator::Free.");
+    PADDLE_ENFORCE(err, "cudaFree{Host} failed in GPUAllocator::Free.");
   }
 }
diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index bc64bfd7ec2ed27835e5a3f9135343aeb3d4a580..a37720e5093342f5e02bd9a15a3099de434d6396 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -51,3 +51,5 @@ op_library(softmax_op SRCS softmax_op.cc softmax_op.cu)
 
 op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op
         softmax_op net)
+
+op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
diff --git a/paddle/operators/sgd_op.cc b/paddle/operators/sgd_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..04df87a3add2af7daa127a072f7b690f6cf94327
--- /dev/null
+++ b/paddle/operators/sgd_op.cc
@@ -0,0 +1,61 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/sgd_op.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/framework/tensor.h"
+
+namespace paddle {
+namespace operators {
+
+class SGDOp : public framework::OperatorWithKernel {
+protected:
+  void InferShape(
+      const std::vector<const framework::Tensor *> &inputs,
+      const std::vector<framework::Tensor *> &outputs) const override {
+    PADDLE_ENFORCE(inputs.size() == 2, "Input size of SGDOp must be two");
+    PADDLE_ENFORCE(outputs.size() == 1, "Output size of SGDOp must be one");
+    PADDLE_ENFORCE(inputs[0] != nullptr, "inputs[0] must be set");
+    PADDLE_ENFORCE(inputs[1] != nullptr, "inputs[1] must be set");
+    PADDLE_ENFORCE(outputs[0] != nullptr, "outputs[0] must be set");
+    PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(),
+                   "The two inputs of SGD Op must have the same dimension.");
+    outputs[0]->set_dims(inputs[0]->dims());
+  }
+};
+
+class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
+public:
+  SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("param", "input parameter");
+    AddInput("grad", "input gradient");
+    AddOutput("param_out", "output parameter");
+    AddAttr<float>("learning_rate", "learning rate of sgd");
+    AddComment(R"DOC(
+
+Simplest sgd algorithm.
+
+param_out = param - learning_rate * grad;
+
+)DOC");
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+REGISTER_OP(sgd, paddle::operators::SGDOp, paddle::operators::SGDOpMaker);
+typedef paddle::operators::SGDOpKernel<::paddle::platform::CPUPlace, float>
+    SGDOpKernel_CPU_float;
+REGISTER_OP_CPU_KERNEL(sgd, SGDOpKernel_CPU_float);
diff --git a/paddle/operators/sgd_op.cu b/paddle/operators/sgd_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..400425db10896e3970fc7468e34aba596a536184
--- /dev/null
+++ b/paddle/operators/sgd_op.cu
@@ -0,0 +1,5 @@
+#include "paddle/operators/sgd_op.h"
+#include "paddle/framework/op_registry.h"
+
+typedef paddle::operators::SGDOpKernel<::paddle::platform::GPUPlace, float> SGDOpKernel_GPU_float;
+REGISTER_OP_GPU_KERNEL(sgd, SGDOpKernel_GPU_float);
\ No newline at end of file
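Editor's note: the new operator implements vanilla SGD, elementwise `param_out = param - learning_rate * grad`, as stated in the maker's DOC string. As a reference for what the kernel in `sgd_op.h` (next file) computes, here is the same rule in plain C++ over flat arrays — a sketch independent of Paddle's Tensor/Eigen types:

```cpp
#include <cstddef>

// Plain-C++ reference for the SGD update the kernel performs:
//   param_out[i] = param[i] - learning_rate * grad[i]
void sgd_update(const float* param, const float* grad, float* param_out,
                std::size_t n, float learning_rate) {
  for (std::size_t i = 0; i < n; ++i) {
    param_out[i] = param[i] - learning_rate * grad[i];
  }
}
```

The Python test at the end of this patch checks exactly this identity with NumPy.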
diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..4b2d214618e5c7c15695bd66604139d805255c47
--- /dev/null
+++ b/paddle/operators/sgd_op.h
@@ -0,0 +1,42 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "glog/logging.h"
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/operator.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename Place, typename T>
+class SGDOpKernel : public framework::OpKernel {
+public:
+  void Compute(const framework::KernelContext& ctx) const override {
+    auto param = ctx.Input("param")->Get<framework::Tensor>();
+    auto grad = ctx.Input("grad")->Get<framework::Tensor>();
+    auto* param_out = ctx.Output(0)->GetMutable<framework::Tensor>();
+    float lr = ctx.op_.GetAttr<float>("learning_rate");
+
+    param_out->mutable_data<T>(ctx.GetPlace());
+
+    framework::EigenVector<T>::Flatten(*param_out)
+        .device(*(ctx.GetEigenDevice<Place>())) =
+        framework::EigenVector<T>::Flatten(param) -
+        lr * framework::EigenVector<T>::Flatten(grad);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
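Editor's note: the kernel body is a single Eigen expression — both tensors are flattened to rank-1 views, and assignment through `.device(...)` evaluates the expression on the kernel's device (CPU or GPU), which is how one source file serves both places. A rough standalone equivalent with plain Eigen tensors on the default CPU device (illustrative only, not Paddle's API):

```cpp
#include "unsupported/Eigen/CXX11/Tensor"

// Illustrative Eigen-only version of the kernel's expression; flattening
// is implicit here because we already work with rank-1 tensors.
void sgd_update(Eigen::Tensor<float, 1>& param_out,
                const Eigen::Tensor<float, 1>& param,
                const Eigen::Tensor<float, 1>& grad, float lr) {
  // The right-hand side builds a lazy expression tree; the assignment
  // evaluates it elementwise in a single pass.
  param_out = param - grad * lr;
}
```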
diff --git a/paddle/operators/sgd_op_test.cc b/paddle/operators/sgd_op_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..75137259f5e608b259b073101353e5818bb17c92
--- /dev/null
+++ b/paddle/operators/sgd_op_test.cc
@@ -0,0 +1,22 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+#include <paddle/framework/op_registry.h>
+USE_OP(sgd);
+TEST(SGDOp, GetOpProto) {
+  auto& protos = paddle::framework::OpRegistry::protos();
+  auto it = protos.find("sgd");
+  ASSERT_NE(it, protos.end());
+}
diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt
index 6ac4035c0f863c5f63d17b523a7a8be668ff3da0..bd77bb7daa50e0b273f110624ddf6f4b79a3ceab 100644
--- a/paddle/platform/CMakeLists.txt
+++ b/paddle/platform/CMakeLists.txt
@@ -8,6 +8,8 @@ cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
 
 add_subdirectory(dynload)
 
+cc_test(enforce_test SRCS enforce_test.cc)
+
 IF(WITH_GPU)
   set(GPU_CTX_DEPS dynload_cuda dynamic_loader)
 ELSE()
diff --git a/paddle/platform/cpu_info.cc b/paddle/platform/cpu_info.cc
index dfab391cfbe1f04bc2a998233f7e7909579ca72b..78e1fa9df56b1623bfd9a53c6a37524d29648afc 100644
--- a/paddle/platform/cpu_info.cc
+++ b/paddle/platform/cpu_info.cc
@@ -22,7 +22,6 @@ limitations under the License. */
 #endif
 
 #include "gflags/gflags.h"
-#include "paddle/platform/error.h"
 
 DEFINE_double(fraction_of_cpu_memory_to_use, 1,
               "Default use 100% of CPU memory for PaddlePaddle,"
diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h
index f226a75c20b7a75e5f884cd158d139ebb8b34e47..fe6f13e399a78f9e5230ae52b0f67ab465af373b 100644
--- a/paddle/platform/device_context.h
+++ b/paddle/platform/device_context.h
@@ -11,12 +11,13 @@ limitations under the License. */
 
 #pragma once
 
-#include "paddle/framework/enforce.h"
+#include "paddle/platform/enforce.h"
+#include "paddle/platform/place.h"
+
 #ifndef PADDLE_ONLY_CPU
 #include "paddle/platform/dynload/cublas.h"
 #include "paddle/platform/dynload/cudnn.h"
 #include "paddle/platform/dynload/curand.h"
-#include "paddle/platform/error.h"
 #include "paddle/platform/gpu_info.h"
 #define EIGEN_USE_GPU
 #endif
@@ -71,8 +72,7 @@ class CUDADeviceContext : public DeviceContext {
  public:
   explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
     GPUPlaceGuard guard(gpu_place_);
-    paddle::platform::throw_on_error(cudaStreamCreate(&stream_),
-                                     "cudaStreamCreate failed");
+    PADDLE_ENFORCE(cudaStreamCreate(&stream_), "cudaStreamCreate failed");
     eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_));
     eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
   }
@@ -83,8 +83,8 @@ class CUDADeviceContext : public DeviceContext {
   }
 
   void Wait() {
-    paddle::platform::throw_on_error(cudaStreamSynchronize(stream_),
-                                     "cudaStreamSynchronize failed");
+    PADDLE_ENFORCE(cudaStreamSynchronize(stream_),
+                   "cudaStreamSynchronize failed");
   }
 
   cudaStream_t stream() { return stream_; }
@@ -94,12 +94,11 @@ class CUDADeviceContext : public DeviceContext {
   cublasHandle_t cublas_handle() {
     if (!blas_handle_) {
       GPUPlaceGuard guard(gpu_place_);
-      PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_) ==
-                         CUBLAS_STATUS_SUCCESS,
+      PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_),
                      "cublasCreate failed");
-      PADDLE_ENFORCE(paddle::platform::dynload::cublasSetStream(
-                         blas_handle_, stream_) == CUBLAS_STATUS_SUCCESS,
-                     "cublasSetStream failed");
+      PADDLE_ENFORCE(
+          paddle::platform::dynload::cublasSetStream(blas_handle_, stream_),
+          "cublasSetStream failed");
     }
     return blas_handle_;
   }
@@ -107,12 +106,11 @@ class CUDADeviceContext : public DeviceContext {
   cudnnHandle_t cudnn_handle() {
     if (!dnn_handle_) {
       GPUPlaceGuard guard(gpu_place_);
-      PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_) ==
-                         CUDNN_STATUS_SUCCESS,
+      PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_),
                      "cudnnCreate failed");
-      PADDLE_ENFORCE(paddle::platform::dynload::cudnnSetStream(
-                         dnn_handle_, stream_) == CUDNN_STATUS_SUCCESS,
-                     "cudnnSetStream failed");
+      PADDLE_ENFORCE(
+          paddle::platform::dynload::cudnnSetStream(dnn_handle_, stream_),
+          "cudnnSetStream failed");
     }
     return dnn_handle_;
   }
@@ -121,16 +119,15 @@ class CUDADeviceContext : public DeviceContext {
     if (!rand_generator_) {
       GPUPlaceGuard guard(gpu_place_);
       PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator(
-                         &rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) ==
-                         CURAND_STATUS_SUCCESS,
+                         &rand_generator_, CURAND_RNG_PSEUDO_DEFAULT),
                      "curandCreateGenerator failed");
       PADDLE_ENFORCE(
           paddle::platform::dynload::curandSetPseudoRandomGeneratorSeed(
-              rand_generator_, random_seed_) == CURAND_STATUS_SUCCESS,
+              rand_generator_, random_seed_),
           "curandSetPseudoRandomGeneratorSeed failed");
-      PADDLE_ENFORCE(paddle::platform::dynload::curandSetStream(
-                         rand_generator_, stream_) == CURAND_STATUS_SUCCESS,
-                     "curandSetStream failed");
+      PADDLE_ENFORCE(
+          paddle::platform::dynload::curandSetStream(rand_generator_, stream_),
+          "curandSetStream failed");
     }
     return rand_generator_;
  }
@@ -138,26 +135,23 @@ class CUDADeviceContext : public DeviceContext {
   ~CUDADeviceContext() {
     Wait();
     if (blas_handle_) {
-      PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_) ==
-                         CUBLAS_STATUS_SUCCESS,
+      PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_),
                      "cublasDestroy failed");
     }
 
     if (dnn_handle_) {
-      PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_) ==
-                         CUDNN_STATUS_SUCCESS,
+      PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_),
                      "cudnnDestroy failed");
     }
 
     if (rand_generator_) {
-      PADDLE_ENFORCE(paddle::platform::dynload::curandDestroyGenerator(
-                         rand_generator_) == CURAND_STATUS_SUCCESS,
-                     "curandDestroyGenerator failed");
+      PADDLE_ENFORCE(
+          paddle::platform::dynload::curandDestroyGenerator(rand_generator_),
+          "curandDestroyGenerator failed");
     }
 
     eigen_stream_.reset();
     eigen_device_.reset();
-    paddle::platform::throw_on_error(cudaStreamDestroy(stream_),
-                                     "cudaStreamDestroy failed");
+    PADDLE_ENFORCE(cudaStreamDestroy(stream_), "cudaStreamDestroy failed");
   }
 
 private:
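Editor's note on the pattern change above: call sites no longer compare against `CUBLAS_STATUS_SUCCESS` and friends. `PADDLE_ENFORCE` forwards its first argument to `throw_on_error` (defined in the new `enforce.h` below), and ordinary overload resolution on the status type — `cudaError_t`, `cublasStatus_t`, `cudnnStatus_t`, `curandStatus_t`, or a plain `int`/`bool` condition — picks the checker that knows what "success" means for that API. A minimal sketch of the dispatch idea, using hypothetical status types rather than the real CUDA ones:

```cpp
#include <stdexcept>
#include <string>

// Hypothetical status enum standing in for cublasStatus_t etc.
enum class BlasStatus { kSuccess, kAllocFailed };

// Each overload encodes its own notion of success.
inline void throw_on_error(BlasStatus s, const std::string& msg) {
  if (s != BlasStatus::kSuccess) throw std::runtime_error("BLAS: " + msg);
}
inline void throw_on_error(bool ok, const std::string& msg) {
  if (!ok) throw std::runtime_error(msg);
}

#define MY_ENFORCE(stat, msg) ::throw_on_error((stat), (msg))

int main() {
  MY_ENFORCE(2 + 2 == 4, "arithmetic is broken");       // bool overload
  MY_ENFORCE(BlasStatus::kSuccess, "handle creation");  // status overload
}
```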
diff --git a/paddle/platform/dynload/dynamic_loader.cc b/paddle/platform/dynload/dynamic_loader.cc
index dd914e006d54c423ffea56ffaaafe7dcba416361..ae9a0a982c73de05821579d22b7f9ad99f24a92b 100644
--- a/paddle/platform/dynload/dynamic_loader.cc
+++ b/paddle/platform/dynload/dynamic_loader.cc
@@ -19,7 +19,7 @@ limitations under the License. */
 #include
 #include "gflags/gflags.h"
 #include "glog/logging.h"
-#include "paddle/framework/enforce.h"
+#include "paddle/platform/enforce.h"
 
 DEFINE_string(cudnn_dir, "",
               "Specify path for loading libcudnn.so. For instance, "
diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h
new file mode 100644
index 0000000000000000000000000000000000000000..5d440dec48e7a4cba404bc297eca5a451a144d93
--- /dev/null
+++ b/paddle/platform/enforce.h
@@ -0,0 +1,141 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include
+#include
+#include
+#include
+
+#ifndef PADDLE_ONLY_CPU
+
+#include "paddle/platform/dynload/cublas.h"
+#include "paddle/platform/dynload/cudnn.h"
+#include "paddle/platform/dynload/curand.h"
+
+#include
+#include
+#include
+#include
+#include
+
+#endif  // PADDLE_ONLY_CPU
+
+namespace paddle {
+namespace platform {
+
+// Because most enforce conditions evaluate to true, we use
+// __builtin_expect to tell the compiler to predict the branch as taken,
+// which generates faster binary code. Note that __builtin_expect is a
+// GCC/Clang builtin, not part of standard C++.
+// For more details, please check https://stackoverflow.com/a/43870188/724872.
+#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
+
+#ifndef PADDLE_ONLY_CPU
+
+template <typename... Args>
+inline void throw_on_error(cudaError_t e, const Args&... args) {
+  if (UNLIKELY(e)) {
+    // clang-format off
+    throw thrust::system_error(
+        e, thrust::cuda_category(),
+        string::Sprintf(args...) +
+        string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
+    // clang-format on
+  }
+}
+
+template <typename... Args>
+inline void throw_on_error(curandStatus_t stat, const Args&... args) {
+  if (stat != CURAND_STATUS_SUCCESS) {
+    // clang-format off
+    throw thrust::system_error(
+        cudaErrorLaunchFailure, thrust::cuda_category(),
+        string::Sprintf(args...) +
+        string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
+    // clang-format on
+  }
+}
+
+template <typename... Args>
+inline void throw_on_error(cudnnStatus_t stat, const Args&... args) {
+  if (stat == CUDNN_STATUS_SUCCESS) {
+    return;
+  } else {
+    // clang-format off
+    throw std::runtime_error(
+        platform::dynload::cudnnGetErrorString(stat) +
+        string::Sprintf(args...) +
+        string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
+    // clang-format on
+  }
+}
+
+template <typename... Args>
+inline void throw_on_error(cublasStatus_t stat, const Args&... args) {
+  std::string err;
+  if (stat == CUBLAS_STATUS_SUCCESS) {
+    return;
+  } else if (stat == CUBLAS_STATUS_NOT_INITIALIZED) {
+    err = "CUBLAS: not initialized, ";
+  } else if (stat == CUBLAS_STATUS_ALLOC_FAILED) {
+    err = "CUBLAS: alloc failed, ";
+  } else if (stat == CUBLAS_STATUS_INVALID_VALUE) {
+    err = "CUBLAS: invalid value, ";
+  } else if (stat == CUBLAS_STATUS_ARCH_MISMATCH) {
+    err = "CUBLAS: arch mismatch, ";
+  } else if (stat == CUBLAS_STATUS_MAPPING_ERROR) {
+    err = "CUBLAS: mapping error, ";
+  } else if (stat == CUBLAS_STATUS_EXECUTION_FAILED) {
+    err = "CUBLAS: execution failed, ";
+  } else if (stat == CUBLAS_STATUS_INTERNAL_ERROR) {
+    err = "CUBLAS: internal error, ";
+  } else if (stat == CUBLAS_STATUS_NOT_SUPPORTED) {
+    err = "CUBLAS: not supported, ";
+  } else if (stat == CUBLAS_STATUS_LICENSE_ERROR) {
+    err = "CUBLAS: license error, ";
+  }
+  throw std::runtime_error(err + string::Sprintf(args...) +
+                           string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
+}
+
+#endif  // PADDLE_ONLY_CPU
+
+template <typename... Args>
+inline void throw_on_error(int stat, const Args&... args) {
+  if (UNLIKELY(!(stat))) {
+    throw std::runtime_error(
+        string::Sprintf(args...) +
+        string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
+  }
+}
+
+#define PADDLE_THROW(...)                                     \
+  do {                                                        \
+    throw std::runtime_error(                                 \
+        string::Sprintf(__VA_ARGS__) +                        \
+        string::Sprintf(" at [%s:%s];", __FILE__, __LINE__)); \
+  } while (0)
+
+/**
+ * @brief Enforce a condition, otherwise throw a std::runtime_error
+ */
+#define PADDLE_ENFORCE(condition, ...)                          \
+  do {                                                          \
+    ::paddle::platform::throw_on_error(condition, __VA_ARGS__); \
+  } while (0)
+
+}  // namespace platform
+}  // namespace paddle
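Editor's note on `UNLIKELY`: `__builtin_expect` tells the optimizer which way a branch almost always goes, so the error path can be laid out off the hot path. A tiny self-contained example of the same hint (GCC/Clang only):

```cpp
#include <cstdio>

// GCC/Clang builtin branch hint: the condition is expected to be false.
#define UNLIKELY(cond) __builtin_expect(static_cast<bool>(cond), 0)

int checked_div(int a, int b) {
  if (UNLIKELY(b == 0)) {  // cold path: almost never taken
    std::fprintf(stderr, "division by zero\n");
    return 0;
  }
  return a / b;            // hot path: laid out to fall through
}
```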
diff --git a/paddle/framework/enforce_test.cc b/paddle/platform/enforce_test.cc
similarity index 85%
rename from paddle/framework/enforce_test.cc
rename to paddle/platform/enforce_test.cc
index f8da1a192f63a54324d80725c9d2f156fb11a481..d7152f81509a35e4ce36d5649e7d209f51e34b86 100644
--- a/paddle/framework/enforce_test.cc
+++ b/paddle/platform/enforce_test.cc
@@ -9,8 +9,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include
-#include
+#include "paddle/platform/enforce.h"
+#include "gtest/gtest.h"
 
 TEST(ENFORCE, OK) {
   PADDLE_ENFORCE(true, "Enforce is ok %d now %f", 123, 0.345);
@@ -23,13 +23,14 @@ TEST(ENFORCE, FAILED) {
   bool in_catch = false;
   try {
     PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123);
-  } catch (paddle::framework::EnforceNotMet err) {
+  } catch (const std::runtime_error& error) {
+    // your error handling code here
     in_catch = true;
     std::string msg = "Enforce is not ok 123 at all";
-    const char* what = err.what();
+    const char* what = error.what();
     for (size_t i = 0; i < msg.length(); ++i) {
       ASSERT_EQ(what[i], msg[i]);
     }
   }
   ASSERT_TRUE(in_catch);
-}
\ No newline at end of file
+}
diff --git a/paddle/platform/error.h b/paddle/platform/error.h
deleted file mode 100644
index 93424bb61096503a4843da7942853a113f612e3b..0000000000000000000000000000000000000000
--- a/paddle/platform/error.h
+++ /dev/null
@@ -1,87 +0,0 @@
-#pragma once
-
-#include
-#include
-#include
-
-#ifndef PADDLE_ONLY_CPU
-
-#include
-#include
-#include
-#include
-#include
-
-#endif  // PADDLE_ONLY_CPU
-
-namespace paddle {
-namespace platform {
-
-#ifndef PADDLE_ONLY_CPU
-
-inline void throw_on_error(cudaError_t e, const char* message) {
-  if (e) {
-    throw thrust::system_error(e, thrust::cuda_category(), message);
-  }
-}
-
-inline void throw_on_error(curandStatus_t stat, const char* message) {
-  if (stat != CURAND_STATUS_SUCCESS) {
-    throw thrust::system_error(cudaErrorLaunchFailure, thrust::cuda_category(),
-                               message);
-  }
-}
-
-inline void throw_on_error(cudnnStatus_t stat, const char* message) {
-  std::stringstream ss;
-  if (stat == CUDNN_STATUS_SUCCESS) {
-    return;
-  } else {
-    ss << cudnnGetErrorString(stat);
-    ss << ", " << message;
-    throw std::runtime_error(ss.str());
-  }
-}
-
-inline void throw_on_error(cublasStatus_t stat, const char* message) {
-  std::stringstream ss;
-  if (stat == CUBLAS_STATUS_SUCCESS) {
-    return;
-  } else if (stat == CUBLAS_STATUS_NOT_INITIALIZED) {
-    ss << "CUBLAS: not initialized";
-  } else if (stat == CUBLAS_STATUS_ALLOC_FAILED) {
-    ss << "CUBLAS: alloc failed";
-  } else if (stat == CUBLAS_STATUS_INVALID_VALUE) {
-    ss << "CUBLAS: invalid value";
-  } else if (stat == CUBLAS_STATUS_ARCH_MISMATCH) {
-    ss << "CUBLAS: arch mismatch";
-  } else if (stat == CUBLAS_STATUS_MAPPING_ERROR) {
-    ss << "CUBLAS: mapping error";
-  } else if (stat == CUBLAS_STATUS_EXECUTION_FAILED) {
-    ss << "CUBLAS: execution failed";
-  } else if (stat == CUBLAS_STATUS_INTERNAL_ERROR) {
-    ss << "CUBLAS: internal error";
-  } else if (stat == CUBLAS_STATUS_NOT_SUPPORTED) {
-    ss << "CUBLAS: not supported";
-  } else if (stat == CUBLAS_STATUS_LICENSE_ERROR) {
-    ss << "CUBLAS: license error";
-  }
-  ss << ", " << message;
-  throw std::runtime_error(ss.str());
-}
-
-inline void throw_on_error(cublasStatus_t stat) {
-  const char* message = "";
-  throw_on_error(stat, message);
-}
-
-#endif  // PADDLE_ONLY_CPU
-
-inline void throw_on_error(int stat, const char* message) {
-  if (stat) {
-    throw std::runtime_error(message + (", stat = " + std::to_string(stat)));
-  }
-}
-
-}  // namespace platform
-}  // namespace paddle
diff --git a/paddle/platform/gpu_info.cc b/paddle/platform/gpu_info.cc
index a1383d3524aedf834c329425419b989d47668bea..cf9921e870d47fe77c0cca80828dbf2bb36ccda8 100644
--- a/paddle/platform/gpu_info.cc
+++ b/paddle/platform/gpu_info.cc
@@ -14,7 +14,7 @@ limitations under the License. */
 
 #include "paddle/platform/gpu_info.h"
 
 #include "gflags/gflags.h"
-#include "paddle/platform/error.h"
+#include "paddle/platform/enforce.h"
 
 DEFINE_double(fraction_of_gpu_memory_to_use, 0.95,
               "Default use 95% of GPU memory for PaddlePaddle,"
@@ -25,7 +25,7 @@ namespace platform {
 
 int GetDeviceCount() {
   int count;
-  throw_on_error(
+  PADDLE_ENFORCE(
       cudaGetDeviceCount(&count),
       "cudaGetDeviceCount failed in paddle::platform::GetDeviceCount");
   return count;
@@ -33,19 +33,19 @@ int GetDeviceCount() {
 
 int GetCurrentDeviceId() {
   int device_id;
-  throw_on_error(
+  PADDLE_ENFORCE(
       cudaGetDevice(&device_id),
       "cudaGetDevice failed in paddle::platform::GetCurrentDeviceId");
   return device_id;
 }
 
 void SetDeviceId(int id) {
-  throw_on_error(cudaSetDevice(id),
+  PADDLE_ENFORCE(cudaSetDevice(id),
                  "cudaSetDevice failed in paddle::platform::SetDeviceId");
 }
 
 void GpuMemoryUsage(size_t& available, size_t& total) {
-  throw_on_error(cudaMemGetInfo(&available, &total),
+  PADDLE_ENFORCE(cudaMemGetInfo(&available, &total),
                  "cudaMemGetInfo failed in paddle::platform::GetMemoryUsage");
 }
diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt
index 29fb29c7c14f699e6114cc25c265ea8d85bce4d7..6354dd211d5d036e1b5971babaf624e8f847a92b 100644
--- a/paddle/pybind/CMakeLists.txt
+++ b/paddle/pybind/CMakeLists.txt
@@ -1,2 +1,2 @@
 cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python
-    add_op fc_op)
+    add_op fc_op sgd_op)
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index 7e84550f770e8dba998ce7ff91b9d774acbffc3e..54707a2859693af4a80692bf5cebab59c43ffbc3 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -28,6 +28,7 @@ namespace pd = paddle::framework;
 
 USE_OP(add_two);
 USE_OP_WITHOUT_KERNEL(fc);
+USE_OP(sgd);
 
 PYBIND11_PLUGIN(core) {
   py::module m("core", "C++ core of Paddle Paddle");
diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt
index f71009aa8569beae330b18171043d456b59bca8d..ec076e40c9312fee7f3ba030dc69208069fd45a8 100644
--- a/python/paddle/v2/framework/tests/CMakeLists.txt
+++ b/python/paddle/v2/framework/tests/CMakeLists.txt
@@ -1,3 +1,3 @@
 add_python_test(test_framework test_protobuf.py test_scope.py
     test_default_scope_funcs.py test_op_creation_methods.py
-    test_tensor.py test_fc_op.py test_add_two_op.py)
+    test_tensor.py test_fc_op.py test_add_two_op.py test_sgd_op.py)
diff --git a/python/paddle/v2/framework/tests/test_sgd_op.py b/python/paddle/v2/framework/tests/test_sgd_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..405d73b224fa153e50b4ec408a921f2bdaab46aa
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_sgd_op.py
@@ -0,0 +1,18 @@
+import unittest
+import numpy
+from op_test_util import OpTestMeta
+
+
+class TestSGD(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = "sgd"
+        self.param = numpy.random.random((342, 345)).astype("float32")
+        self.grad = numpy.random.random((342, 345)).astype("float32")
+        self.learning_rate = 0.1
+        self.param_out = self.param - self.learning_rate * self.grad
+
+
+if __name__ == "__main__":
+    unittest.main()