diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 5874ef2f1f332cc9e8c8d074b6bcf577c55fcda0..196a0837d530fc9100893e8811d4674602b98a85 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -30,7 +30,7 @@ using DeviceContext = platform::DeviceContext; class EmptyOp : public OperatorBase { public: - DEFINE_OPERATOR_CTOR(EmptyOp, OperatorBase); + using OperatorBase::OperatorBase; void InferShape(const Scope &scope) const override {} void Run(const Scope &scope, const DeviceContext &dev_ctx) const override {} }; @@ -79,8 +79,9 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker { class FcOp : public operators::NetOp { public: - DEFINE_OPERATOR_CTOR(FcOp, operators::NetOp) - void Init() override { + FcOp(const std::string &type, const VarNameMap &inputs, + const VarNameMap &outputs, const AttributeMap &attrs) + : NetOp(type, inputs, outputs, attrs) { AddOp(OpRegistry::CreateOp("mul", {{"X", {Input("X")}}, {"Y", {Input("W")}}}, {{"Out", {Output("mul_result")}}}, {})); diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc index 048864c7004872df7fe6336675b4e3012f41709a..1833a5463aac7620472e1508f8ef75a894fc02c6 100644 --- a/paddle/framework/grad_op_builder.cc +++ b/paddle/framework/grad_op_builder.cc @@ -19,14 +19,12 @@ namespace paddle { namespace framework { enum class OpArgType { IN, OUT }; -using VarNameMap = OperatorBase::VarNameMap; - -static VarNameMap TransOpArg(const OperatorBase* src_op, - const OpArgType& src_type, - const OpArgType& dst_type, bool is_grad) { +static void TransOpArg(const OperatorBase* src_op, + OperatorBase::VarNameMap* vars, + const OpArgType& src_type, bool is_grad) { const auto& src_inout = - src_type == OpArgType::IN ? src_op->Inputs() : src_op->Outputs(); - VarNameMap dst_inout; + src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_; + auto& dst_inout = *vars; const OpProto& proto = OpProtos().at(src_op->Type()); const auto& src_arg_list = @@ -45,20 +43,22 @@ static VarNameMap TransOpArg(const OperatorBase* src_op, } OperatorBase* BuildGradOp(const OperatorBase* op) { - std::string grad_op_type = OpRegistry::grad_ops().at(op->Type()); - auto I = TransOpArg(op, OpArgType::IN, OpArgType::IN, false); // I - auto O = TransOpArg(op, OpArgType::OUT, OpArgType::IN, false); // O - auto OG = TransOpArg(op, OpArgType::OUT, OpArgType::IN, true); // OG - auto IG = TransOpArg(op, OpArgType::IN, OpArgType::OUT, true); // IG - // TODO(merge I/O/OG) - VarNameMap GradIn; - GradIn.insert(I.begin(), I.end()); - GradIn.insert(O.begin(), O.end()); - GradIn.insert(OG.begin(), OG.end()); + auto gop_type_it = OpRegistry::grad_ops().find(op->type_); + PADDLE_ENFORCE(gop_type_it != OpRegistry::grad_ops().end(), + "Operator %s do not register gradient type", op->type_); + auto& grad_op_type = gop_type_it->second; + OperatorBase::VarNameMap inputs; + OperatorBase::VarNameMap outputs; + TransOpArg(op, &inputs, OpArgType::IN, false); // I + TransOpArg(op, &inputs, OpArgType::OUT, false); // O + TransOpArg(op, &inputs, OpArgType::OUT, true); // OG + TransOpArg(op, &outputs, OpArgType::IN, true); // IG + auto gop_it = OpRegistry::op_creators().find(grad_op_type); + PADDLE_ENFORCE(gop_it != OpRegistry::op_creators().end(), + "Operator %s 's Gradient %s's creator cannot be found", + op->type_, grad_op_type); - OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)( - grad_op_type, GradIn, IG, op->Attrs()); - return grad_op; + return gop_it->second(grad_op_type, inputs, outputs, op->attrs_); } } // namespace framework diff --git a/paddle/framework/grad_op_builder_test.cc b/paddle/framework/grad_op_builder_test.cc index 75c6ec8b5672242d3ac008dd7ba663e35ca530a0..ebaf84545fce0d281d8821861264cddc8854893d 100644 --- a/paddle/framework/grad_op_builder_test.cc +++ b/paddle/framework/grad_op_builder_test.cc @@ -10,7 +10,7 @@ namespace framework { class NOP : public OperatorBase { public: - DEFINE_OPERATOR_CTOR(NOP, OperatorBase); + using OperatorBase::OperatorBase; void InferShape(const Scope &scope) const override {} void Run(const Scope &scope, const platform::DeviceContext &dev_ctx) const override {} diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index ffd48160b8af385961440873c8e9b525e1f5618b..af965df7ec9962f314758ecf897f918ed463d864 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -122,8 +122,8 @@ class OpProtoAndCheckerMaker { class OpRegistry { using VarNameMap = OperatorBase::VarNameMap; using OpCreator = std::function; + const std::string& /*type*/, const VarNameMap& /*inputs*/, + const VarNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>; public: template @@ -158,7 +158,7 @@ class OpRegistry { static std::shared_ptr CreateOp(const std::string& type, const VarNameMap& inputs, const VarNameMap& outputs, - const AttributeMap& attrs) { + AttributeMap attrs) { auto op_create_it = op_creators().find(type); PADDLE_ENFORCE(op_create_it != op_creators().end(), "Operator %s cannot be found.", type); @@ -168,7 +168,6 @@ class OpRegistry { auto op = op_create_it->second(type, inputs, outputs, attrMap); GenerateTempVariableName(op); - op->Init(); return std::shared_ptr(op); } @@ -200,7 +199,6 @@ class OpRegistry { PADDLE_ENFORCE(!op.IsNetOp(), "Use framework::Backward to get backward ops"); std::shared_ptr grad_op(BuildGradOp(&op)); - grad_op->Init(); return grad_op; } diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index 17cbd8563ceef5cdfadb842efa3eb052c1e77151..0b8f8289490135b8976c38fa3fb3c2995c50416f 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -7,7 +7,7 @@ namespace paddle { namespace framework { class CosineOp : public OperatorBase { public: - DEFINE_OPERATOR_CTOR(CosineOp, OperatorBase); + using OperatorBase::OperatorBase; void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const override {} void InferShape(const Scope& scope) const override {} @@ -28,7 +28,7 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class MyTestOp : public OperatorBase { public: - DEFINE_OPERATOR_CTOR(MyTestOp, OperatorBase); + using OperatorBase::OperatorBase; void InferShape(const Scope& scope) const override {} void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const override {} diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 351a544c0be443fba9bf813648a421dd0d365411..13442a72b9d77a4858b5d91dd7690e089ec7ed49 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -122,6 +122,23 @@ void OperatorBase::Rename(const std::string& old_name, } } +OperatorBase::OperatorBase(const std::string& type, + const OperatorBase::VarNameMap& inputs, + const OperatorBase::VarNameMap& outputs, + const AttributeMap& attrs) + : type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) { + static std::atomic gUniqId(0UL); + for (auto& output : outputs_) { + for (auto& output_name : output.second) { + if (output_name == kTempVarName) { + output_name += type_; + output_name += "@"; + output_name += std::to_string(gUniqId.fetch_add(1)); + } + } + } +} + std::vector OperatorBase::OutputVars(bool has_intermediate) const { std::vector ret_val; if (has_intermediate) { diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index acff4f0ca007352ce6ebbb48c4035ea980d3ba32..90569c509fb44bc8b86385d15ab95733da14c448 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -66,10 +66,8 @@ class OperatorBase { public: using VarNameMap = std::map>; - OperatorBase() = default; OperatorBase(const std::string& type, const VarNameMap& inputs, - const VarNameMap& outputs, const AttributeMap& attrs) - : type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) {} + const VarNameMap& outputs, const AttributeMap& attrs); OperatorBase(const OperatorBase& o) = delete; OperatorBase& operator=(const OperatorBase& o) = delete; @@ -86,10 +84,6 @@ class OperatorBase { virtual std::string DebugString() const; - /// Init will be called after CreateOperator, you can put some initialization - /// logic here. - virtual void Init() {} - /// InferShape infer the size of Variables used by this Operator with /// information inside scope virtual void InferShape(const Scope& scope) const = 0; @@ -138,15 +132,6 @@ class OperatorBase { AttributeMap attrs_; }; -#define DEFINE_OPERATOR_CTOR(Class, ParentClass) \ - public: \ - Class() : ParentClass() { /* TODO(yi): This constructor is to be removed. */ \ - } \ - Class(const std::string& type, const VarNameMap& inputs, \ - const VarNameMap& outputs, \ - const paddle::framework::AttributeMap& attrs) \ - : ParentClass(type, inputs, outputs, attrs) {} - class InferShapeContext { public: InferShapeContext(const OperatorBase& op, const Scope& scope) @@ -267,6 +252,10 @@ class ExecutionContext : public InferShapeContext { platform::Place GetPlace() const { return device_context_->GetPlace(); } + const platform::DeviceContext* device_context() const { + return device_context_; + } + const platform::DeviceContext* device_context_; }; @@ -286,8 +275,6 @@ class OpKernel { class OperatorWithKernel : public OperatorBase { public: - DEFINE_OPERATOR_CTOR(OperatorWithKernel, OperatorBase) - struct OpKernelKey { platform::Place place_; @@ -311,6 +298,10 @@ class OperatorWithKernel : public OperatorBase { using OpKernelMap = std::unordered_map, OpKernelHash>; + OperatorWithKernel(const std::string& type, const VarNameMap& inputs, + const VarNameMap& outputs, const AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + void InferShape(const Scope& scope) const override { InferShape(InferShapeContext(*this, scope)); } diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index d975145a21d3c6dfac4b9405304561d58bec9f94..6804841587730d51d9cfad30a9de81401d36695b 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -22,10 +22,10 @@ namespace framework { static int op_run_num = 0; class OpWithoutKernelTest : public OperatorBase { - DEFINE_OPERATOR_CTOR(OpWithoutKernelTest, framework::OperatorBase) - public: - void Init() override { x = 1; } + OpWithoutKernelTest(const std::string& type, const VarNameMap& inputs, + const VarNameMap& outputs, const AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs), x(1) {} void InferShape(const Scope& scope) const override {} void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const override { @@ -38,7 +38,7 @@ class OpWithoutKernelTest : public OperatorBase { } public: - float x = 0; + int x{0}; }; class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { @@ -109,7 +109,9 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { static int cpu_kernel_run_num = 0; class OpWithKernelTest : public OperatorWithKernel { - DEFINE_OPERATOR_CTOR(OpWithKernelTest, framework::OperatorWithKernel) + public: + using OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext& ctx) const override {} }; diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index cd1b4de426a49fa66dbbf8cf7d09990ac8d21227..b8c779f4e5fc7bc51298cdd35b26c2c8ac98edf6 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -105,6 +105,8 @@ class Tensor { template inline Tensor Slice(const int& begin_idx, const int& end_idx) const; + platform::Place place() const { return holder_->place(); } + private: template inline void check_memory_size() const; diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 7f56aaa92cc45d81440084cdeb3c6eb3b6fda3df..373611cc0ee952de813f01d32d1516e1a8384750 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -41,6 +41,7 @@ function(op_library TARGET) endif() endfunction() +add_subdirectory(math) cc_test(gather_test SRCS gather_test.cc DEPS tensor) cc_library(net_op SRCS net_op.cc DEPS op_registry) @@ -50,7 +51,7 @@ op_library(add_op SRCS add_op.cc add_op.cu) op_library(mean_op SRCS mean_op.cc mean_op.cu) -op_library(mul_op SRCS mul_op.cc mul_op.cu) +op_library(mul_op SRCS mul_op.cc mul_op.cu DEPS math_function) op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc) op_library(sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu) diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc index bf0982e095138a62b655599167ea2ec715987667..c1f647a88e4547d96bbb9143cdb2cb07bc291635 100644 --- a/paddle/operators/add_op.cc +++ b/paddle/operators/add_op.cc @@ -18,7 +18,8 @@ namespace paddle { namespace operators { class AddOp : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(AddOp, framework::OperatorWithKernel) + public: + using framework::OperatorWithKernel::OperatorWithKernel; protected: void InferShape(const framework::InferShapeContext &ctx) const override { @@ -45,7 +46,9 @@ The equation is: Out = X + Y }; class AddOpGrad : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(AddOpGrad, framework::OperatorWithKernel) + public: + using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext &ctx) const override {} }; diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc index e40351a1c1abf293b9b6dab4545ae547ebc1d7de..597c71d4e042e6b6a752c0b1819b909a7a9faa75 100644 --- a/paddle/operators/cross_entropy_op.cc +++ b/paddle/operators/cross_entropy_op.cc @@ -18,7 +18,9 @@ namespace paddle { namespace operators { class OnehotCrossEntropyOp : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(OnehotCrossEntropyOp, framework::OperatorWithKernel) + public: + using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext &ctx) const override { auto *X = ctx.Input("X"); @@ -32,8 +34,9 @@ class OnehotCrossEntropyOp : public framework::OperatorWithKernel { }; class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(OnehotCrossEntropyGradientOp, - framework::OperatorWithKernel) + public: + using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext &ctx) const override { auto X_grad = ctx.Output(framework::GradVarName("X")); diff --git a/paddle/operators/fill_zeros_like_op.cc b/paddle/operators/fill_zeros_like_op.cc index 881d4128bba508af44bdd887c4cfd99231ed1127..e42e33f1a3759ae26cee987d0b68a55b672e3f94 100644 --- a/paddle/operators/fill_zeros_like_op.cc +++ b/paddle/operators/fill_zeros_like_op.cc @@ -18,7 +18,8 @@ namespace paddle { namespace operators { class FillZerosLikeOp : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(FillZerosLikeOp, framework::OperatorWithKernel); + public: + using framework::OperatorWithKernel::OperatorWithKernel; protected: void InferShape(const framework::InferShapeContext &ctx) const override { diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc index 9a4d4addd40ee90797cf3f3bcf469ec4bdf4c88e..75249c08eb00095615fc75eb9261432d64246b2e 100644 --- a/paddle/operators/gaussian_random_op.cc +++ b/paddle/operators/gaussian_random_op.cc @@ -43,7 +43,8 @@ class GaussianRandomKernel : public framework::OpKernel { }; class GaussianRandomOp : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(GaussianRandomOp, framework::OperatorWithKernel); + public: + using framework::OperatorWithKernel::OperatorWithKernel; protected: void InferShape(const framework::InferShapeContext& context) const override { diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..abcaf940ab0128d6948acc620d678632c8f48960 --- /dev/null +++ b/paddle/operators/math/CMakeLists.txt @@ -0,0 +1,13 @@ +if(WITH_MKLML) + set(BLAS_LIB mklml) +else() + set(BLAS_LIB cblas) +endif() + +if(WITH_GPU) + nv_library(math_function SRCS math_function.cc math_function.cu DEPS ${BLAS_LIB} device_context) +else() + cc_library(math_function SRCS math_function.cc DEPS ${BLAS_LIB} device_context) +endif() + +nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc new file mode 100644 index 0000000000000000000000000000000000000000..affdd1ac2cd486930881ee6b34a4b32f41df7ee9 --- /dev/null +++ b/paddle/operators/math/math_function.cc @@ -0,0 +1,114 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { +namespace math { + +template <> +void gemm(const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, const int M, + const int N, const int K, + const float alpha, const float* A, + const float* B, const float beta, float* C, + platform::DeviceContext* context) { + int lda = K; + int ldb = N; + int ldc = N; + cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, + beta, C, ldc); +} + +template <> +void gemm(const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, const int M, + const int N, const int K, + const double alpha, const double* A, + const double* B, const double beta, + double* C, + platform::DeviceContext* context) { + int lda = K; + int ldb = N; + int ldc = N; + cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, + beta, C, ldc); +} + +template <> +void matmul(const framework::Tensor& matrix_a, + bool trans_a, + const framework::Tensor& matrix_b, + bool trans_b, float alpha, + framework::Tensor* matrix_out, + float beta, + platform::DeviceContext* context) { + auto dim_a = matrix_a.dims(); + auto dim_b = matrix_b.dims(); + auto dim_out = matrix_out->dims(); + PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, + "The input and output of matmul be matrix"); + + PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) && + platform::is_cpu_place(matrix_b.place()) && + platform::is_cpu_place(matrix_out->place()), + "Matrix must all be in CPUPlace"); + + int M = dim_out[0]; + int N = dim_out[1]; + int K = (trans_a == false) ? dim_a[1] : dim_a[0]; + + CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans; + + gemm( + transA, transB, M, N, K, alpha, matrix_a.data(), + matrix_b.data(), beta, matrix_out->data(), context); +} + +template <> +void matmul(const framework::Tensor& matrix_a, + bool trans_a, + const framework::Tensor& matrix_b, + bool trans_b, double alpha, + framework::Tensor* matrix_out, + double beta, + platform::DeviceContext* context) { + auto dim_a = matrix_a.dims(); + auto dim_b = matrix_b.dims(); + auto dim_out = matrix_out->dims(); + PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, + "The input and output of matmul be matrix"); + + PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) && + platform::is_cpu_place(matrix_b.place()) && + platform::is_cpu_place(matrix_out->place()), + "Matrix must all be in CPUPlace"); + + int M = dim_out[0]; + int N = dim_out[1]; + int K = (trans_a == false) ? dim_a[1] : dim_a[0]; + + CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans; + + gemm( + transA, transB, M, N, K, alpha, matrix_a.data(), + matrix_b.data(), beta, matrix_out->data(), context); +} + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu new file mode 100644 index 0000000000000000000000000000000000000000..da40b27c948918e4997f4a046d2145552296158b --- /dev/null +++ b/paddle/operators/math/math_function.cu @@ -0,0 +1,127 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { +namespace math { + +template <> +void gemm(const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, const int M, + const int N, const int K, + const float alpha, const float* A, + const float* B, const float beta, float* C, + platform::DeviceContext* context) { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + + PADDLE_ENFORCE(platform::dynload::cublasSgemm( + reinterpret_cast(context)->cublas_handle(), + cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N)); +} + +template <> +void gemm(const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, const int M, + const int N, const int K, + const double alpha, const double* A, + const double* B, const double beta, + double* C, + platform::DeviceContext* context) { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + PADDLE_ENFORCE(platform::dynload::cublasDgemm( + reinterpret_cast(context)->cublas_handle(), + cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N)); +} + +template <> +void matmul(const framework::Tensor& matrix_a, + bool trans_a, + const framework::Tensor& matrix_b, + bool trans_b, float alpha, + framework::Tensor* matrix_out, + float beta, + platform::DeviceContext* context) { + auto dim_a = matrix_a.dims(); + auto dim_b = matrix_b.dims(); + auto dim_out = matrix_out->dims(); + PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, + "The input and output of matmul be matrix"); + + PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) && + platform::is_gpu_place(matrix_b.place()) && + platform::is_gpu_place(matrix_out->place()), + "Matrix must all be in GPUPlace"); + + int M = dim_out[0]; + int N = dim_out[1]; + int K = (trans_a == false) ? dim_a[1] : dim_a[0]; + + CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans; + + gemm( + transA, transB, M, N, K, alpha, matrix_a.data(), + matrix_b.data(), beta, matrix_out->data(), context); +} + +template <> +void matmul(const framework::Tensor& matrix_a, + bool trans_a, + const framework::Tensor& matrix_b, + bool trans_b, double alpha, + framework::Tensor* matrix_out, + double beta, + platform::DeviceContext* context) { + auto dim_a = matrix_a.dims(); + auto dim_b = matrix_b.dims(); + auto dim_out = matrix_out->dims(); + PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, + "The input and output of matmul be matrix"); + + PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) && + platform::is_gpu_place(matrix_b.place()) && + platform::is_gpu_place(matrix_out->place()), + "Matrix must all be in GPUPlace"); + + int M = dim_out[0]; + int N = dim_out[1]; + int K = (trans_a == false) ? dim_a[1] : dim_a[0]; + + CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans; + + gemm( + transA, transB, M, N, K, alpha, matrix_a.data(), + matrix_b.data(), beta, matrix_out->data(), context); +} + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h new file mode 100644 index 0000000000000000000000000000000000000000..155589fadb3ed9f59160a750d546dd8093a56cbe --- /dev/null +++ b/paddle/operators/math/math_function.h @@ -0,0 +1,82 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#ifdef PADDLE_USE_MKLML +#include +#include +#include +#endif + +#ifdef PADDLE_USE_MKL +#include +#include +#endif + +#ifdef PADDLE_USE_ATLAS +extern "C" { +#include +#include +} +#endif + +#ifdef PADDLE_USE_OPENBLAS +#include +#include +#endif + +#ifndef LAPACK_FOUND +extern "C" { +#include +int LAPACKE_sgetrf(int matrix_layout, int m, int n, float* a, int lda, + int* ipiv); +int LAPACKE_dgetrf(int matrix_layout, int m, int n, double* a, int lda, + int* ipiv); +int LAPACKE_sgetri(int matrix_layout, int n, float* a, int lda, + const int* ipiv); +int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda, + const int* ipiv); +} +#endif + +#include + +#include "paddle/framework/tensor.h" +#include "paddle/platform/device_context.h" +#include "paddle/platform/enforce.h" + +namespace paddle { +namespace operators { +namespace math { + +// Support continuous memory now +// If transA = N, and transB = N +// Then matrixA: M * K, matrixB: K * N matrixC : M * N +// For more detailed info, please refer to +// http://www.netlib.org/lapack/explore-html/d4/de2/sgemm_8f.html +template +void gemm(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, + const int M, const int N, const int K, const T alpha, const T* A, + const T* B, const T beta, T* C, platform::DeviceContext* context); + +// matrix multiply with continuous memory +template +void matmul(const framework::Tensor& matrix_a, bool trans_a, + const framework::Tensor& matrix_b, bool trans_b, T alpha, + framework::Tensor* matrix_out, T beta, + platform::DeviceContext* context); + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/math_function_test.cc b/paddle/operators/math/math_function_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6c020c4ff7285b43bc5836d80c173d3a068e72b3 --- /dev/null +++ b/paddle/operators/math/math_function_test.cc @@ -0,0 +1,75 @@ +#include "paddle/operators/math/math_function.h" +#include "gtest/gtest.h" + +#ifndef PADDLE_ONLY_CPU +TEST(math_function, notrans_mul_trans) { + paddle::framework::Tensor input1; + paddle::framework::Tensor input1_gpu; + paddle::framework::Tensor input2_gpu; + paddle::framework::Tensor out_gpu; + paddle::framework::Tensor out; + + auto* cpu_place = new paddle::platform::CPUPlace(); + float* input1_ptr = input1.mutable_data({2, 3}, *cpu_place); + float arr[6] = {0, 1, 2, 3, 4, 5}; + memcpy(input1_ptr, arr, 6 * sizeof(float)); + + auto* gpu_place = new paddle::platform::GPUPlace(0); + paddle::platform::DeviceContext* context = + new paddle::platform::CUDADeviceContext(*gpu_place); + + input1_gpu.CopyFrom(input1, *gpu_place); + input2_gpu.CopyFrom(input1, *gpu_place); + + out_gpu.mutable_data({2, 2}, *gpu_place); + + paddle::operators::math::matmul( + input1_gpu, false, input2_gpu, true, 1, &out_gpu, 0, context); + + out.CopyFrom(out_gpu, *cpu_place); + + float* out_ptr = out.data(); + EXPECT_EQ(out_ptr[0], 5); + EXPECT_EQ(out_ptr[1], 14); + EXPECT_EQ(out_ptr[2], 14); + EXPECT_EQ(out_ptr[3], 50); +} + +TEST(math_function, trans_mul_notrans) { + paddle::framework::Tensor input1; + paddle::framework::Tensor input1_gpu; + paddle::framework::Tensor input2_gpu; + paddle::framework::Tensor out_gpu; + paddle::framework::Tensor out; + + auto* cpu_place = new paddle::platform::CPUPlace(); + float* input1_ptr = input1.mutable_data({2, 3}, *cpu_place); + float arr[6] = {0, 1, 2, 3, 4, 5}; + memcpy(input1_ptr, arr, 6 * sizeof(float)); + + auto* gpu_place = new paddle::platform::GPUPlace(0); + paddle::platform::DeviceContext* context = + new paddle::platform::CUDADeviceContext(*gpu_place); + + input1_gpu.CopyFrom(input1, *gpu_place); + input2_gpu.CopyFrom(input1, *gpu_place); + + out_gpu.mutable_data({3, 3}, *gpu_place); + + paddle::operators::math::matmul( + input1_gpu, true, input2_gpu, false, 1, &out_gpu, 0, context); + + out.CopyFrom(out_gpu, *cpu_place); + + float* out_ptr = out.data(); + EXPECT_EQ(out_ptr[0], 9); + EXPECT_EQ(out_ptr[1], 12); + EXPECT_EQ(out_ptr[2], 15); + EXPECT_EQ(out_ptr[3], 12); + EXPECT_EQ(out_ptr[4], 17); + EXPECT_EQ(out_ptr[5], 22); + EXPECT_EQ(out_ptr[6], 15); + EXPECT_EQ(out_ptr[7], 22); + EXPECT_EQ(out_ptr[8], 29); +} +#endif diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc index 3b258a6bd02a192071e6fc171b724959c498adf6..35e7212dde210a50285272cfd94118fa34fb7cd9 100644 --- a/paddle/operators/mean_op.cc +++ b/paddle/operators/mean_op.cc @@ -18,7 +18,9 @@ namespace paddle { namespace operators { class MeanOp : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(MeanOp, framework::OperatorWithKernel) + public: + using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext &ctx) const override { PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), @@ -38,7 +40,9 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker { }; class MeanGradOp : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(MeanGradOp, framework::OperatorWithKernel) + public: + using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext &ctx) const override { ctx.Output(framework::GradVarName("X")) diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index ae924375c2fb27104ffeb98268aec36fafde3c69..032d234197c12fe107fb195e862c160948ee354c 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -13,12 +13,14 @@ limitations under the License. */ #include "paddle/operators/mul_op.h" +#include "paddle/operators/math/math_function.h" namespace paddle { namespace operators { class MulOp : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(MulOp, framework::OperatorWithKernel); + public: + using framework::OperatorWithKernel::OperatorWithKernel; protected: void InferShape(const framework::InferShapeContext &ctx) const override { @@ -53,7 +55,9 @@ The equation is: Out = X * Y }; class MulOpGrad : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(MulOpGrad, framework::OperatorWithKernel) + public: + using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext &ctx) const override {} std::string DebugString() const override { diff --git a/paddle/operators/mul_op.cu b/paddle/operators/mul_op.cu index 43debbc21a365a15c914e60e151f7782b82080cb..346a7e505d123b5e4e831daa39a1f6349b3dcccf 100644 --- a/paddle/operators/mul_op.cu +++ b/paddle/operators/mul_op.cu @@ -16,5 +16,4 @@ #include "paddle/operators/mul_op.h" namespace ops = paddle::operators; - REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel); diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h index ca3105fa4f158064c822a319e2c9c5a40e39d481..b7812fd1a7a72f5ce543e18c8b7b5b51deff2204 100644 --- a/paddle/operators/mul_op.h +++ b/paddle/operators/mul_op.h @@ -13,6 +13,9 @@ limitations under the License. */ #pragma once + +#include "paddle/operators/math/math_function.h" + #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" diff --git a/paddle/operators/net_op.cc b/paddle/operators/net_op.cc index 61e1377af8ae5b93c08a5920e182e6f8da4376d1..c36fe8d6b58a0afa568e31e43567baa5f261c7d0 100644 --- a/paddle/operators/net_op.cc +++ b/paddle/operators/net_op.cc @@ -81,5 +81,11 @@ std::vector NetOp::OutputVars(bool has_intermediate) const { return ret_val; } +NetOp::NetOp(const std::string& type, + const framework::OperatorBase::VarNameMap& inputs, + const framework::OperatorBase::VarNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + } // namespace operators } // namespace paddle diff --git a/paddle/operators/net_op.h b/paddle/operators/net_op.h index 4560578121ed28ac1e150e1ffdd41bda97050f67..4a3408c158a029a96740717280c1562671fa938f 100644 --- a/paddle/operators/net_op.h +++ b/paddle/operators/net_op.h @@ -37,7 +37,9 @@ namespace operators { class NetOp : public framework::OperatorBase { public: static const char kAll[]; - DEFINE_OPERATOR_CTOR(NetOp, framework::OperatorBase); + NetOp() : framework::OperatorBase("plain_net", {}, {}, {}) {} + NetOp(const std::string& type, const VarNameMap& inputs, + const VarNameMap& outputs, const framework::AttributeMap& attrs); /** * Infer all the operators' input and output variables' shapes, will be called diff --git a/paddle/operators/net_op_test.cc b/paddle/operators/net_op_test.cc index 8872c8d92baea2912b96faffdd2075bcf249f77b..f7aa56262ef71c24bf668950f6e9914e5f96ff70 100644 --- a/paddle/operators/net_op_test.cc +++ b/paddle/operators/net_op_test.cc @@ -12,7 +12,7 @@ static int run_cnt = 0; class TestOp : public framework::OperatorBase { public: - DEFINE_OPERATOR_CTOR(TestOp, framework::OperatorBase); + using framework::OperatorBase::OperatorBase; void InferShape(const Scope& scope) const override { ++infer_shape_cnt; } void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const override { @@ -22,7 +22,7 @@ class TestOp : public framework::OperatorBase { class EmptyOp : public framework::OperatorBase { public: - DEFINE_OPERATOR_CTOR(EmptyOp, framework::OperatorBase); + using framework::OperatorBase::OperatorBase; void InferShape(const Scope& scope) const override {} void Run(const Scope& scope, const DeviceContext& dev_ctx) const override {} }; @@ -44,14 +44,14 @@ TEST(OpKernel, all) { auto net = std::make_shared(); ASSERT_NE(net, nullptr); - auto op1 = std::make_shared(); - op1->inputs_ = {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}}; - op1->outputs_ = {{"Out", {"y"}}}; + auto op1 = std::shared_ptr( + new TestOp("test", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}}, + {{"Out", {"y"}}}, {})); net->AddOp(op1); - auto op2 = std::make_shared(); - op2->inputs_ = {{"X", {"y"}}, {"W", {"w2"}}, {"b", {"b2"}}}; - op2->outputs_ = {{"Out", {"z"}}}; + auto op2 = std::shared_ptr( + new TestOp("test", {{"X", {"y"}}, {"W", {"w2"}}, {"b", {"b2"}}}, + {{"Out", {"z"}}}, {})); net->AddOp(op2); net->CompleteAddOp(); @@ -67,9 +67,9 @@ TEST(OpKernel, all) { TEST(NetOp, insert_op) { NetOp net; - auto op1 = std::make_shared(); - op1->inputs_ = {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}}; - op1->outputs_ = {{"Out", {"y"}}}; + auto op1 = std::shared_ptr( + new EmptyOp("empty", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}}, + {{"Out", {"y"}}}, {})); net.AddOp(op1); net.InsertOp(0, op1); ASSERT_EQ(2UL, net.ops_.size()); diff --git a/paddle/operators/recurrent_op.cc b/paddle/operators/recurrent_op.cc index 5e6ba6b8dd0c71359a039a4b777feba6aab606f7..5ddee75581824996fd312f8ddf13007759fd9a67 100644 --- a/paddle/operators/recurrent_op.cc +++ b/paddle/operators/recurrent_op.cc @@ -135,8 +135,11 @@ const rnn::ArgumentName RecurrentGradientOp::kArgName{ "inlink@grad", "inlink_alias", "outlink_alias", "memories", "pre_memories", "boot_memories@grad"}; -void RecurrentOp::Init() { - OperatorBase::Init(); +RecurrentOp::RecurrentOp(const std::string& type, + const framework::OperatorBase::VarNameMap& inputs, + const framework::OperatorBase::VarNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) { std::unique_ptr arg(new rnn::Argument()); rnn::InitArgument(kArgName, arg.get(), *this); alg_.Init(std::move(arg)); @@ -230,8 +233,11 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const { LinkBootMemoryGradients(step_scopes[0], true /*infer_shape_mode*/); } -void RecurrentGradientOp::Init() { - OperatorBase::Init(); +RecurrentGradientOp::RecurrentGradientOp( + const std::string& type, const framework::OperatorBase::VarNameMap& inputs, + const framework::OperatorBase::VarNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) { std::unique_ptr arg(new rnn::Argument()); rnn::InitArgument(kArgName, arg.get(), *this); alg_.Init(std::move(arg)); diff --git a/paddle/operators/recurrent_op.h b/paddle/operators/recurrent_op.h index b22ac0ddc9a5f6d3137bfb02e39c84a4b1517c7c..8f4f2444d844b4ba5948f001a365a7ecaeecc106 100644 --- a/paddle/operators/recurrent_op.h +++ b/paddle/operators/recurrent_op.h @@ -101,13 +101,11 @@ class RecurrentGradientAlgorithm { class RecurrentOp final : public framework::OperatorBase { public: - DEFINE_OPERATOR_CTOR(RecurrentOp, framework::OperatorBase); - - void Init() override; - + RecurrentOp(const std::string& type, const VarNameMap& inputs, + const VarNameMap& outputs, const framework::AttributeMap& attrs); /** - * InferShape must be called before Run. - */ + * InferShape must be called before Run. + */ void InferShape(const framework::Scope& scope) const override { alg_.InferShape(scope); } @@ -125,8 +123,9 @@ class RecurrentOp final : public framework::OperatorBase { class RecurrentGradientOp final : public framework::OperatorBase { public: - DEFINE_OPERATOR_CTOR(RecurrentGradientOp, framework::OperatorBase) - void Init() override; + RecurrentGradientOp(const std::string& type, const VarNameMap& inputs, + const VarNameMap& outputs, + const framework::AttributeMap& attrs); /** * InferShape must be called before Run. diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc index fcc6e163641ccb5b9df1d2e6e84a53eef1791cef..b4671c293af1c4fed3b441f05bc8f3a5db039b41 100644 --- a/paddle/operators/rowwise_add_op.cc +++ b/paddle/operators/rowwise_add_op.cc @@ -18,7 +18,9 @@ namespace paddle { namespace operators { class RowWiseAddOp : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(RowWiseAddOp, framework::OperatorWithKernel) + public: + using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext &ctx) const override { auto dim0 = ctx.Input("X")->dims(); diff --git a/paddle/operators/sgd_op.cc b/paddle/operators/sgd_op.cc index 29a6a77006430ecde6e44dc0cd28e940871f484d..bf76df272b6faaed01ed8d715fe3b547ec7dc4e3 100644 --- a/paddle/operators/sgd_op.cc +++ b/paddle/operators/sgd_op.cc @@ -18,7 +18,9 @@ namespace paddle { namespace operators { class SGDOp : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(SGDOp, framework::OperatorWithKernel) + public: + using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext &ctx) const override { PADDLE_ENFORCE( diff --git a/paddle/operators/sigmoid_op.cc b/paddle/operators/sigmoid_op.cc index 40a8ba12d7e32b5db1ab14b7ab647c327e65c6fe..a7dfb624e5b779164eb07763eb604c548f6e89e7 100644 --- a/paddle/operators/sigmoid_op.cc +++ b/paddle/operators/sigmoid_op.cc @@ -18,7 +18,9 @@ namespace paddle { namespace operators { class SigmoidOp : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(SigmoidOp, framework::OperatorWithKernel) + public: + using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext &ctx) const override { ctx.Output("Y")->Resize(ctx.Input("X")->dims()); @@ -37,7 +39,9 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker { }; class SigmoidOpGrad : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(SigmoidOpGrad, framework::OperatorWithKernel) + public: + using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext &ctx) const override { ctx.Output(0)->Resize(ctx.Input(0)->dims()); diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc index 16351b4bbd98e68d5b22f82f61f7b700ca90b559..5d8ece1a254a58990bfb2f919567fa43689335b9 100644 --- a/paddle/operators/softmax_op.cc +++ b/paddle/operators/softmax_op.cc @@ -18,7 +18,9 @@ namespace paddle { namespace operators { class SoftmaxOp : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(SoftmaxOp, framework::OperatorWithKernel) + public: + using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext &ctx) const override { PADDLE_ENFORCE(ctx.Input("X")->dims().size() == 2UL, @@ -39,7 +41,9 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { }; class SoftmaxOpGrad : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(SoftmaxOpGrad, framework::OperatorWithKernel) + public: + using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext &ctx) const override { PADDLE_ENFORCE(ctx.InputVar("Y") != nullptr, "Input(Y) should not be null"); diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc index 8c40eed9d4c7b80b3a111f240fdab1f28f73ee06..9d668e6085b93bc5a3a06683aa4470f62ae47c02 100644 --- a/paddle/operators/uniform_random_op.cc +++ b/paddle/operators/uniform_random_op.cc @@ -46,7 +46,9 @@ class CPUUniformRandomKernel : public framework::OpKernel { }; class UniformRandomOp : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(UniformRandomOp, framework::OperatorWithKernel) + public: + using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext& ctx) const override { PADDLE_ENFORCE(GetAttr("min") < GetAttr("max"), diff --git a/paddle/platform/dynload/cublas.h b/paddle/platform/dynload/cublas.h index aad8097dbb33cbf6c0f2b4b3efb1376fbe96bc74..9d8343c0b5e200b390ccda760f09816959952e9d 100644 --- a/paddle/platform/dynload/cublas.h +++ b/paddle/platform/dynload/cublas.h @@ -62,12 +62,12 @@ extern void *cublas_dso_handle; DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name) #define CUBLAS_BLAS_ROUTINE_EACH(__macro) \ - __macro(cublasSgemv); \ - __macro(cublasDgemv); \ - __macro(cublasSgemm); \ - __macro(cublasDgemm); \ - __macro(cublasSgeam); \ - __macro(cublasDgeam); \ + __macro(cublasSgemv_v2); \ + __macro(cublasDgemv_v2); \ + __macro(cublasSgemm_v2); \ + __macro(cublasDgemm_v2); \ + __macro(cublasSgeam_v2); \ + __macro(cublasDgeam_v2); \ __macro(cublasCreate_v2); \ __macro(cublasDestroy_v2); \ __macro(cublasSetStream_v2); \ diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h index 337a059fb1494d500be0fd2437e59c863ae1563c..15fdf7a94f462a87f7edae1429eb0c4da0b17a84 100644 --- a/paddle/platform/enforce.h +++ b/paddle/platform/enforce.h @@ -14,14 +14,21 @@ limitations under the License. */ #pragma once -#include +#include // for dladdr +#include // for backtrace #include +#include #include #include #include + #include "paddle/string/printf.h" #include "paddle/string/to_string.h" +#ifdef __GNUC__ +#include // for __cxa_demangle +#endif + #ifndef PADDLE_ONLY_CPU #include "paddle/platform/dynload/cublas.h" @@ -39,6 +46,19 @@ limitations under the License. */ namespace paddle { namespace platform { +namespace { +#ifdef __GNUC__ +inline std::string demangle(std::string name) { + int status = -4; // some arbitrary value to eliminate the compiler warning + std::unique_ptr res{ + abi::__cxa_demangle(name.c_str(), NULL, NULL, &status), std::free}; + return (status == 0) ? res.get() : name; +} +#else +inline std::string demangle(std::string name) { return name; } +#endif +} + struct EnforceNotMet : public std::exception { std::exception_ptr exp_; std::string err_str_; @@ -48,15 +68,29 @@ struct EnforceNotMet : public std::exception { std::rethrow_exception(exp_); } catch (const std::exception& exp) { std::ostringstream sout; + sout << string::Sprintf("%s at [%s:%d]", exp.what(), f, l) << std::endl; - sout << "Call Stacks: " << std::endl; + sout << "PaddlePaddle Call Stacks: " << std::endl; + void* call_stack[TRACE_STACK_LIMIT]; - int sz = backtrace(call_stack, TRACE_STACK_LIMIT); - auto line = backtrace_symbols(call_stack, sz); - for (int i = 0; i < sz; ++i) { - sout << line[i] << std::endl; + auto size = backtrace(call_stack, TRACE_STACK_LIMIT); + auto symbols = backtrace_symbols(call_stack, size); + + Dl_info info; + for (int i = 0; i < size; ++i) { + if (dladdr(call_stack[i], &info)) { + auto demangled = demangle(info.dli_sname); + auto addr_offset = static_cast(call_stack[i]) - + static_cast(info.dli_saddr); + sout << string::Sprintf("%-3d %*0p %s + %zd\n", i, + 2 + sizeof(void*) * 2, call_stack[i], + demangled, addr_offset); + } else { + sout << string::Sprintf("%-3d %*0p %s\n", i, 2 + sizeof(void*) * 2, + call_stack[i]); + } } - free(line); + free(symbols); err_str_ = sout.str(); } } @@ -170,7 +204,7 @@ inline void throw_on_error(T e) { * PADDLE_ENFORCE_EQ(a, b); * * will raise an expression described as follows: - * "enforce a == b failed, 1 != 2" with detailed stack infomation. + * "enforce a == b failed, 1 != 2" with detailed stack information. * * extra messages is also supported, for example: * PADDLE_ENFORCE(a, b, "some simple enforce failed between %d numbers", 2)