提交 2f74e608 编写于 作者: Q qiaolongfei

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into public_to_protected

......@@ -30,7 +30,7 @@ using DeviceContext = platform::DeviceContext;
class EmptyOp : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(EmptyOp, OperatorBase);
using OperatorBase::OperatorBase;
void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope, const DeviceContext &dev_ctx) const override {}
};
......@@ -79,8 +79,9 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker {
class FcOp : public operators::NetOp {
public:
DEFINE_OPERATOR_CTOR(FcOp, operators::NetOp)
void Init() override {
FcOp(const std::string &type, const VarNameMap &inputs,
const VarNameMap &outputs, const AttributeMap &attrs)
: NetOp(type, inputs, outputs, attrs) {
AddOp(OpRegistry::CreateOp("mul",
{{"X", {Input("X")}}, {"Y", {Input("W")}}},
{{"Out", {Output("mul_result")}}}, {}));
......
......@@ -19,14 +19,12 @@ namespace paddle {
namespace framework {
enum class OpArgType { IN, OUT };
using VarNameMap = OperatorBase::VarNameMap;
static VarNameMap TransOpArg(const OperatorBase* src_op,
const OpArgType& src_type,
const OpArgType& dst_type, bool is_grad) {
static void TransOpArg(const OperatorBase* src_op,
OperatorBase::VarNameMap* vars,
const OpArgType& src_type, bool is_grad) {
const auto& src_inout =
src_type == OpArgType::IN ? src_op->Inputs() : src_op->Outputs();
VarNameMap dst_inout;
src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_;
auto& dst_inout = *vars;
const OpProto& proto = OpProtos().at(src_op->Type());
const auto& src_arg_list =
......@@ -45,20 +43,22 @@ static VarNameMap TransOpArg(const OperatorBase* src_op,
}
OperatorBase* BuildGradOp(const OperatorBase* op) {
std::string grad_op_type = OpRegistry::grad_ops().at(op->Type());
auto I = TransOpArg(op, OpArgType::IN, OpArgType::IN, false); // I
auto O = TransOpArg(op, OpArgType::OUT, OpArgType::IN, false); // O
auto OG = TransOpArg(op, OpArgType::OUT, OpArgType::IN, true); // OG
auto IG = TransOpArg(op, OpArgType::IN, OpArgType::OUT, true); // IG
// TODO(merge I/O/OG)
VarNameMap GradIn;
GradIn.insert(I.begin(), I.end());
GradIn.insert(O.begin(), O.end());
GradIn.insert(OG.begin(), OG.end());
auto gop_type_it = OpRegistry::grad_ops().find(op->type_);
PADDLE_ENFORCE(gop_type_it != OpRegistry::grad_ops().end(),
"Operator %s do not register gradient type", op->type_);
auto& grad_op_type = gop_type_it->second;
OperatorBase::VarNameMap inputs;
OperatorBase::VarNameMap outputs;
TransOpArg(op, &inputs, OpArgType::IN, false); // I
TransOpArg(op, &inputs, OpArgType::OUT, false); // O
TransOpArg(op, &inputs, OpArgType::OUT, true); // OG
TransOpArg(op, &outputs, OpArgType::IN, true); // IG
auto gop_it = OpRegistry::op_creators().find(grad_op_type);
PADDLE_ENFORCE(gop_it != OpRegistry::op_creators().end(),
"Operator %s 's Gradient %s's creator cannot be found",
op->type_, grad_op_type);
OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)(
grad_op_type, GradIn, IG, op->Attrs());
return grad_op;
return gop_it->second(grad_op_type, inputs, outputs, op->attrs_);
}
} // namespace framework
......
......@@ -10,7 +10,7 @@ namespace framework {
class NOP : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(NOP, OperatorBase);
using OperatorBase::OperatorBase;
void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope,
const platform::DeviceContext &dev_ctx) const override {}
......
......@@ -122,8 +122,8 @@ class OpProtoAndCheckerMaker {
class OpRegistry {
using VarNameMap = OperatorBase::VarNameMap;
using OpCreator = std::function<OperatorBase*(
const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs)>;
const std::string& /*type*/, const VarNameMap& /*inputs*/,
const VarNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
public:
template <typename OpType, typename ProtoMakerType>
......@@ -158,7 +158,7 @@ class OpRegistry {
static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
const VarNameMap& inputs,
const VarNameMap& outputs,
const AttributeMap& attrs) {
AttributeMap attrs) {
auto op_create_it = op_creators().find(type);
PADDLE_ENFORCE(op_create_it != op_creators().end(),
"Operator %s cannot be found.", type);
......@@ -168,7 +168,6 @@ class OpRegistry {
auto op = op_create_it->second(type, inputs, outputs, attrMap);
GenerateTempVariableName(op);
op->Init();
return std::shared_ptr<OperatorBase>(op);
}
......@@ -200,7 +199,6 @@ class OpRegistry {
PADDLE_ENFORCE(!op.IsNetOp(),
"Use framework::Backward to get backward ops");
std::shared_ptr<OperatorBase> grad_op(BuildGradOp(&op));
grad_op->Init();
return grad_op;
}
......
......@@ -7,7 +7,7 @@ namespace paddle {
namespace framework {
class CosineOp : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(CosineOp, OperatorBase);
using OperatorBase::OperatorBase;
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
void InferShape(const Scope& scope) const override {}
......@@ -28,7 +28,7 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
class MyTestOp : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(MyTestOp, OperatorBase);
using OperatorBase::OperatorBase;
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
......
......@@ -122,6 +122,23 @@ void OperatorBase::Rename(const std::string& old_name,
}
}
OperatorBase::OperatorBase(const std::string& type,
const OperatorBase::VarNameMap& inputs,
const OperatorBase::VarNameMap& outputs,
const AttributeMap& attrs)
: type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) {
static std::atomic<size_t> gUniqId(0UL);
for (auto& output : outputs_) {
for (auto& output_name : output.second) {
if (output_name == kTempVarName) {
output_name += type_;
output_name += "@";
output_name += std::to_string(gUniqId.fetch_add(1));
}
}
}
}
std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
std::vector<std::string> ret_val;
if (has_intermediate) {
......
......@@ -66,10 +66,8 @@ class OperatorBase {
public:
using VarNameMap = std::map<std::string, std::vector<std::string>>;
OperatorBase() = default;
OperatorBase(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs)
: type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) {}
const VarNameMap& outputs, const AttributeMap& attrs);
OperatorBase(const OperatorBase& o) = delete;
OperatorBase& operator=(const OperatorBase& o) = delete;
......@@ -86,10 +84,6 @@ class OperatorBase {
virtual std::string DebugString() const;
/// Init will be called after CreateOperator, you can put some initialization
/// logic here.
virtual void Init() {}
/// InferShape infer the size of Variables used by this Operator with
/// information inside scope
virtual void InferShape(const Scope& scope) const = 0;
......@@ -138,15 +132,6 @@ class OperatorBase {
AttributeMap attrs_;
};
#define DEFINE_OPERATOR_CTOR(Class, ParentClass) \
public: \
Class() : ParentClass() { /* TODO(yi): This constructor is to be removed. */ \
} \
Class(const std::string& type, const VarNameMap& inputs, \
const VarNameMap& outputs, \
const paddle::framework::AttributeMap& attrs) \
: ParentClass(type, inputs, outputs, attrs) {}
class InferShapeContext {
public:
InferShapeContext(const OperatorBase& op, const Scope& scope)
......@@ -267,6 +252,10 @@ class ExecutionContext : public InferShapeContext {
platform::Place GetPlace() const { return device_context_->GetPlace(); }
const platform::DeviceContext* device_context() const {
return device_context_;
}
const platform::DeviceContext* device_context_;
};
......@@ -286,8 +275,6 @@ class OpKernel {
class OperatorWithKernel : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(OperatorWithKernel, OperatorBase)
struct OpKernelKey {
platform::Place place_;
......@@ -311,6 +298,10 @@ class OperatorWithKernel : public OperatorBase {
using OpKernelMap =
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
OperatorWithKernel(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void InferShape(const Scope& scope) const override {
InferShape(InferShapeContext(*this, scope));
}
......
......@@ -22,10 +22,10 @@ namespace framework {
static int op_run_num = 0;
class OpWithoutKernelTest : public OperatorBase {
DEFINE_OPERATOR_CTOR(OpWithoutKernelTest, framework::OperatorBase)
public:
void Init() override { x = 1; }
OpWithoutKernelTest(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs), x(1) {}
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
......@@ -38,7 +38,7 @@ class OpWithoutKernelTest : public OperatorBase {
}
public:
float x = 0;
int x{0};
};
class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
......@@ -109,7 +109,9 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
static int cpu_kernel_run_num = 0;
class OpWithKernelTest : public OperatorWithKernel {
DEFINE_OPERATOR_CTOR(OpWithKernelTest, framework::OperatorWithKernel)
public:
using OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext& ctx) const override {}
};
......
......@@ -105,6 +105,8 @@ class Tensor {
template <typename T>
inline Tensor Slice(const int& begin_idx, const int& end_idx) const;
platform::Place place() const { return holder_->place(); }
private:
template <typename T>
inline void check_memory_size() const;
......
......@@ -41,6 +41,7 @@ function(op_library TARGET)
endif()
endfunction()
add_subdirectory(math)
cc_test(gather_test SRCS gather_test.cc DEPS tensor)
cc_library(net_op SRCS net_op.cc DEPS op_registry)
......@@ -50,7 +51,7 @@ op_library(add_op SRCS add_op.cc add_op.cu)
op_library(mean_op SRCS mean_op.cc mean_op.cu)
op_library(mul_op SRCS mul_op.cc mul_op.cu)
op_library(mul_op SRCS mul_op.cc mul_op.cu DEPS math_function)
op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc)
op_library(sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu)
......
......@@ -18,7 +18,8 @@ namespace paddle {
namespace operators {
class AddOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(AddOp, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
......@@ -45,7 +46,9 @@ The equation is: Out = X + Y
};
class AddOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(AddOpGrad, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {}
};
......
......@@ -18,7 +18,9 @@ namespace paddle {
namespace operators {
class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(OnehotCrossEntropyOp, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto *X = ctx.Input<Tensor>("X");
......@@ -32,8 +34,9 @@ class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
};
class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(OnehotCrossEntropyGradientOp,
framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
......
......@@ -18,7 +18,8 @@ namespace paddle {
namespace operators {
class FillZerosLikeOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(FillZerosLikeOp, framework::OperatorWithKernel);
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
......
......@@ -43,7 +43,8 @@ class GaussianRandomKernel : public framework::OpKernel {
};
class GaussianRandomOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(GaussianRandomOp, framework::OperatorWithKernel);
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext& context) const override {
......
if(WITH_MKLML)
set(BLAS_LIB mklml)
else()
set(BLAS_LIB cblas)
endif()
if(WITH_GPU)
nv_library(math_function SRCS math_function.cc math_function.cu DEPS ${BLAS_LIB} device_context)
else()
cc_library(math_function SRCS math_function.cc DEPS ${BLAS_LIB} device_context)
endif()
nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
namespace math {
template <>
void gemm<platform::CPUPlace, float>(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M,
const int N, const int K,
const float alpha, const float* A,
const float* B, const float beta, float* C,
platform::DeviceContext* context) {
int lda = K;
int ldb = N;
int ldc = N;
cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
}
template <>
void gemm<platform::CPUPlace, double>(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M,
const int N, const int K,
const double alpha, const double* A,
const double* B, const double beta,
double* C,
platform::DeviceContext* context) {
int lda = K;
int ldb = N;
int ldc = N;
cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
}
template <>
void matmul<platform::CPUPlace, float>(const framework::Tensor& matrix_a,
bool trans_a,
const framework::Tensor& matrix_b,
bool trans_b, float alpha,
framework::Tensor* matrix_out,
float beta,
platform::DeviceContext* context) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) &&
platform::is_cpu_place(matrix_b.place()) &&
platform::is_cpu_place(matrix_out->place()),
"Matrix must all be in CPUPlace");
int M = dim_out[0];
int N = dim_out[1];
int K = (trans_a == false) ? dim_a[1] : dim_a[0];
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
gemm<platform::CPUPlace, float>(
transA, transB, M, N, K, alpha, matrix_a.data<float>(),
matrix_b.data<float>(), beta, matrix_out->data<float>(), context);
}
template <>
void matmul<platform::CPUPlace, double>(const framework::Tensor& matrix_a,
bool trans_a,
const framework::Tensor& matrix_b,
bool trans_b, double alpha,
framework::Tensor* matrix_out,
double beta,
platform::DeviceContext* context) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) &&
platform::is_cpu_place(matrix_b.place()) &&
platform::is_cpu_place(matrix_out->place()),
"Matrix must all be in CPUPlace");
int M = dim_out[0];
int N = dim_out[1];
int K = (trans_a == false) ? dim_a[1] : dim_a[0];
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
gemm<platform::CPUPlace, double>(
transA, transB, M, N, K, alpha, matrix_a.data<double>(),
matrix_b.data<double>(), beta, matrix_out->data<double>(), context);
}
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
namespace math {
template <>
void gemm<platform::GPUPlace, float>(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M,
const int N, const int K,
const float alpha, const float* A,
const float* B, const float beta, float* C,
platform::DeviceContext* context) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
PADDLE_ENFORCE(platform::dynload::cublasSgemm(
reinterpret_cast<platform::CUDADeviceContext*>(context)->cublas_handle(),
cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
}
template <>
void gemm<platform::GPUPlace, double>(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M,
const int N, const int K,
const double alpha, const double* A,
const double* B, const double beta,
double* C,
platform::DeviceContext* context) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
PADDLE_ENFORCE(platform::dynload::cublasDgemm(
reinterpret_cast<platform::CUDADeviceContext*>(context)->cublas_handle(),
cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
}
template <>
void matmul<platform::GPUPlace, float>(const framework::Tensor& matrix_a,
bool trans_a,
const framework::Tensor& matrix_b,
bool trans_b, float alpha,
framework::Tensor* matrix_out,
float beta,
platform::DeviceContext* context) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
platform::is_gpu_place(matrix_b.place()) &&
platform::is_gpu_place(matrix_out->place()),
"Matrix must all be in GPUPlace");
int M = dim_out[0];
int N = dim_out[1];
int K = (trans_a == false) ? dim_a[1] : dim_a[0];
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
gemm<platform::GPUPlace, float>(
transA, transB, M, N, K, alpha, matrix_a.data<float>(),
matrix_b.data<float>(), beta, matrix_out->data<float>(), context);
}
template <>
void matmul<platform::GPUPlace, double>(const framework::Tensor& matrix_a,
bool trans_a,
const framework::Tensor& matrix_b,
bool trans_b, double alpha,
framework::Tensor* matrix_out,
double beta,
platform::DeviceContext* context) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
platform::is_gpu_place(matrix_b.place()) &&
platform::is_gpu_place(matrix_out->place()),
"Matrix must all be in GPUPlace");
int M = dim_out[0];
int N = dim_out[1];
int K = (trans_a == false) ? dim_a[1] : dim_a[0];
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
gemm<platform::GPUPlace, double>(
transA, transB, M, N, K, alpha, matrix_a.data<double>(),
matrix_b.data<double>(), beta, matrix_out->data<double>(), context);
}
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_USE_MKLML
#include <mkl_cblas.h>
#include <mkl_lapacke.h>
#include <mkl_vml_functions.h>
#endif
#ifdef PADDLE_USE_MKL
#include <mkl.h>
#include <mkl_lapacke.h>
#endif
#ifdef PADDLE_USE_ATLAS
extern "C" {
#include <cblas.h>
#include <clapack.h>
}
#endif
#ifdef PADDLE_USE_OPENBLAS
#include <cblas.h>
#include <lapacke.h>
#endif
#ifndef LAPACK_FOUND
extern "C" {
#include <cblas.h>
int LAPACKE_sgetrf(int matrix_layout, int m, int n, float* a, int lda,
int* ipiv);
int LAPACKE_dgetrf(int matrix_layout, int m, int n, double* a, int lda,
int* ipiv);
int LAPACKE_sgetri(int matrix_layout, int n, float* a, int lda,
const int* ipiv);
int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda,
const int* ipiv);
}
#endif
#include <cmath>
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"
namespace paddle {
namespace operators {
namespace math {
// Support continuous memory now
// If transA = N, and transB = N
// Then matrixA: M * K, matrixB: K * N matrixC : M * N
// For more detailed info, please refer to
// http://www.netlib.org/lapack/explore-html/d4/de2/sgemm_8f.html
template <typename Place, typename T>
void gemm(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB,
const int M, const int N, const int K, const T alpha, const T* A,
const T* B, const T beta, T* C, platform::DeviceContext* context);
// matrix multiply with continuous memory
template <typename Place, typename T>
void matmul(const framework::Tensor& matrix_a, bool trans_a,
const framework::Tensor& matrix_b, bool trans_b, T alpha,
framework::Tensor* matrix_out, T beta,
platform::DeviceContext* context);
} // namespace math
} // namespace operators
} // namespace paddle
#include "paddle/operators/math/math_function.h"
#include "gtest/gtest.h"
#ifndef PADDLE_ONLY_CPU
TEST(math_function, notrans_mul_trans) {
paddle::framework::Tensor input1;
paddle::framework::Tensor input1_gpu;
paddle::framework::Tensor input2_gpu;
paddle::framework::Tensor out_gpu;
paddle::framework::Tensor out;
auto* cpu_place = new paddle::platform::CPUPlace();
float* input1_ptr = input1.mutable_data<float>({2, 3}, *cpu_place);
float arr[6] = {0, 1, 2, 3, 4, 5};
memcpy(input1_ptr, arr, 6 * sizeof(float));
auto* gpu_place = new paddle::platform::GPUPlace(0);
paddle::platform::DeviceContext* context =
new paddle::platform::CUDADeviceContext(*gpu_place);
input1_gpu.CopyFrom<float>(input1, *gpu_place);
input2_gpu.CopyFrom<float>(input1, *gpu_place);
out_gpu.mutable_data<float>({2, 2}, *gpu_place);
paddle::operators::math::matmul<paddle::platform::GPUPlace, float>(
input1_gpu, false, input2_gpu, true, 1, &out_gpu, 0, context);
out.CopyFrom<float>(out_gpu, *cpu_place);
float* out_ptr = out.data<float>();
EXPECT_EQ(out_ptr[0], 5);
EXPECT_EQ(out_ptr[1], 14);
EXPECT_EQ(out_ptr[2], 14);
EXPECT_EQ(out_ptr[3], 50);
}
TEST(math_function, trans_mul_notrans) {
paddle::framework::Tensor input1;
paddle::framework::Tensor input1_gpu;
paddle::framework::Tensor input2_gpu;
paddle::framework::Tensor out_gpu;
paddle::framework::Tensor out;
auto* cpu_place = new paddle::platform::CPUPlace();
float* input1_ptr = input1.mutable_data<float>({2, 3}, *cpu_place);
float arr[6] = {0, 1, 2, 3, 4, 5};
memcpy(input1_ptr, arr, 6 * sizeof(float));
auto* gpu_place = new paddle::platform::GPUPlace(0);
paddle::platform::DeviceContext* context =
new paddle::platform::CUDADeviceContext(*gpu_place);
input1_gpu.CopyFrom<float>(input1, *gpu_place);
input2_gpu.CopyFrom<float>(input1, *gpu_place);
out_gpu.mutable_data<float>({3, 3}, *gpu_place);
paddle::operators::math::matmul<paddle::platform::GPUPlace, float>(
input1_gpu, true, input2_gpu, false, 1, &out_gpu, 0, context);
out.CopyFrom<float>(out_gpu, *cpu_place);
float* out_ptr = out.data<float>();
EXPECT_EQ(out_ptr[0], 9);
EXPECT_EQ(out_ptr[1], 12);
EXPECT_EQ(out_ptr[2], 15);
EXPECT_EQ(out_ptr[3], 12);
EXPECT_EQ(out_ptr[4], 17);
EXPECT_EQ(out_ptr[5], 22);
EXPECT_EQ(out_ptr[6], 15);
EXPECT_EQ(out_ptr[7], 22);
EXPECT_EQ(out_ptr[8], 29);
}
#endif
......@@ -18,7 +18,9 @@ namespace paddle {
namespace operators {
class MeanOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MeanOp, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
......@@ -38,7 +40,9 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
};
class MeanGradOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MeanGradOp, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>(framework::GradVarName("X"))
......
......@@ -13,12 +13,14 @@
limitations under the License. */
#include "paddle/operators/mul_op.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
class MulOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MulOp, framework::OperatorWithKernel);
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
......@@ -53,7 +55,9 @@ The equation is: Out = X * Y
};
class MulOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MulOpGrad, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {}
std::string DebugString() const override {
......
......@@ -16,5 +16,4 @@
#include "paddle/operators/mul_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<paddle::platform::GPUPlace, float>);
......@@ -13,6 +13,9 @@
limitations under the License. */
#pragma once
#include "paddle/operators/math/math_function.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
......
......@@ -81,5 +81,11 @@ std::vector<std::string> NetOp::OutputVars(bool has_intermediate) const {
return ret_val;
}
NetOp::NetOp(const std::string& type,
const framework::OperatorBase::VarNameMap& inputs,
const framework::OperatorBase::VarNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
} // namespace operators
} // namespace paddle
......@@ -37,7 +37,9 @@ namespace operators {
class NetOp : public framework::OperatorBase {
public:
static const char kAll[];
DEFINE_OPERATOR_CTOR(NetOp, framework::OperatorBase);
NetOp() : framework::OperatorBase("plain_net", {}, {}, {}) {}
NetOp(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const framework::AttributeMap& attrs);
/**
* Infer all the operators' input and output variables' shapes, will be called
......
......@@ -12,7 +12,7 @@ static int run_cnt = 0;
class TestOp : public framework::OperatorBase {
public:
DEFINE_OPERATOR_CTOR(TestOp, framework::OperatorBase);
using framework::OperatorBase::OperatorBase;
void InferShape(const Scope& scope) const override { ++infer_shape_cnt; }
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
......@@ -22,7 +22,7 @@ class TestOp : public framework::OperatorBase {
class EmptyOp : public framework::OperatorBase {
public:
DEFINE_OPERATOR_CTOR(EmptyOp, framework::OperatorBase);
using framework::OperatorBase::OperatorBase;
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope, const DeviceContext& dev_ctx) const override {}
};
......@@ -44,14 +44,14 @@ TEST(OpKernel, all) {
auto net = std::make_shared<NetOp>();
ASSERT_NE(net, nullptr);
auto op1 = std::make_shared<TestOp>();
op1->inputs_ = {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}};
op1->outputs_ = {{"Out", {"y"}}};
auto op1 = std::shared_ptr<TestOp>(
new TestOp("test", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}},
{{"Out", {"y"}}}, {}));
net->AddOp(op1);
auto op2 = std::make_shared<TestOp>();
op2->inputs_ = {{"X", {"y"}}, {"W", {"w2"}}, {"b", {"b2"}}};
op2->outputs_ = {{"Out", {"z"}}};
auto op2 = std::shared_ptr<TestOp>(
new TestOp("test", {{"X", {"y"}}, {"W", {"w2"}}, {"b", {"b2"}}},
{{"Out", {"z"}}}, {}));
net->AddOp(op2);
net->CompleteAddOp();
......@@ -67,9 +67,9 @@ TEST(OpKernel, all) {
TEST(NetOp, insert_op) {
NetOp net;
auto op1 = std::make_shared<EmptyOp>();
op1->inputs_ = {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}};
op1->outputs_ = {{"Out", {"y"}}};
auto op1 = std::shared_ptr<EmptyOp>(
new EmptyOp("empty", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}},
{{"Out", {"y"}}}, {}));
net.AddOp(op1);
net.InsertOp(0, op1);
ASSERT_EQ(2UL, net.ops_.size());
......
......@@ -135,8 +135,11 @@ const rnn::ArgumentName RecurrentGradientOp::kArgName{
"inlink@grad", "inlink_alias", "outlink_alias",
"memories", "pre_memories", "boot_memories@grad"};
void RecurrentOp::Init() {
OperatorBase::Init();
RecurrentOp::RecurrentOp(const std::string& type,
const framework::OperatorBase::VarNameMap& inputs,
const framework::OperatorBase::VarNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {
std::unique_ptr<rnn::Argument> arg(new rnn::Argument());
rnn::InitArgument(kArgName, arg.get(), *this);
alg_.Init(std::move(arg));
......@@ -230,8 +233,11 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
LinkBootMemoryGradients(step_scopes[0], true /*infer_shape_mode*/);
}
void RecurrentGradientOp::Init() {
OperatorBase::Init();
RecurrentGradientOp::RecurrentGradientOp(
const std::string& type, const framework::OperatorBase::VarNameMap& inputs,
const framework::OperatorBase::VarNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {
std::unique_ptr<rnn::Argument> arg(new rnn::Argument());
rnn::InitArgument(kArgName, arg.get(), *this);
alg_.Init(std::move(arg));
......
......@@ -101,13 +101,11 @@ class RecurrentGradientAlgorithm {
class RecurrentOp final : public framework::OperatorBase {
public:
DEFINE_OPERATOR_CTOR(RecurrentOp, framework::OperatorBase);
void Init() override;
RecurrentOp(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const framework::AttributeMap& attrs);
/**
* InferShape must be called before Run.
*/
* InferShape must be called before Run.
*/
void InferShape(const framework::Scope& scope) const override {
alg_.InferShape(scope);
}
......@@ -125,8 +123,9 @@ class RecurrentOp final : public framework::OperatorBase {
class RecurrentGradientOp final : public framework::OperatorBase {
public:
DEFINE_OPERATOR_CTOR(RecurrentGradientOp, framework::OperatorBase)
void Init() override;
RecurrentGradientOp(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs,
const framework::AttributeMap& attrs);
/**
* InferShape must be called before Run.
......
......@@ -18,7 +18,9 @@ namespace paddle {
namespace operators {
class RowWiseAddOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(RowWiseAddOp, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto dim0 = ctx.Input<Tensor>("X")->dims();
......
......@@ -18,7 +18,9 @@ namespace paddle {
namespace operators {
class SGDOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SGDOp, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(
......
......@@ -18,7 +18,9 @@ namespace paddle {
namespace operators {
class SigmoidOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SigmoidOp, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>("Y")->Resize(ctx.Input<Tensor>("X")->dims());
......@@ -37,7 +39,9 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
};
class SigmoidOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SigmoidOpGrad, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
......
......@@ -18,7 +18,9 @@ namespace paddle {
namespace operators {
class SoftmaxOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SoftmaxOp, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.Input<Tensor>("X")->dims().size() == 2UL,
......@@ -39,7 +41,9 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
};
class SoftmaxOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SoftmaxOpGrad, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputVar("Y") != nullptr, "Input(Y) should not be null");
......
......@@ -46,7 +46,9 @@ class CPUUniformRandomKernel : public framework::OpKernel {
};
class UniformRandomOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(UniformRandomOp, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE(GetAttr<float>("min") < GetAttr<float>("max"),
......
......@@ -62,12 +62,12 @@ extern void *cublas_dso_handle;
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name)
#define CUBLAS_BLAS_ROUTINE_EACH(__macro) \
__macro(cublasSgemv); \
__macro(cublasDgemv); \
__macro(cublasSgemm); \
__macro(cublasDgemm); \
__macro(cublasSgeam); \
__macro(cublasDgeam); \
__macro(cublasSgemv_v2); \
__macro(cublasDgemv_v2); \
__macro(cublasSgemm_v2); \
__macro(cublasDgemm_v2); \
__macro(cublasSgeam_v2); \
__macro(cublasDgeam_v2); \
__macro(cublasCreate_v2); \
__macro(cublasDestroy_v2); \
__macro(cublasSetStream_v2); \
......
......@@ -14,14 +14,21 @@ limitations under the License. */
#pragma once
#include <execinfo.h>
#include <dlfcn.h> // for dladdr
#include <execinfo.h> // for backtrace
#include <iomanip>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include "paddle/string/printf.h"
#include "paddle/string/to_string.h"
#ifdef __GNUC__
#include <cxxabi.h> // for __cxa_demangle
#endif
#ifndef PADDLE_ONLY_CPU
#include "paddle/platform/dynload/cublas.h"
......@@ -39,6 +46,19 @@ limitations under the License. */
namespace paddle {
namespace platform {
namespace {
#ifdef __GNUC__
inline std::string demangle(std::string name) {
int status = -4; // some arbitrary value to eliminate the compiler warning
std::unique_ptr<char, void (*)(void*)> res{
abi::__cxa_demangle(name.c_str(), NULL, NULL, &status), std::free};
return (status == 0) ? res.get() : name;
}
#else
inline std::string demangle(std::string name) { return name; }
#endif
}
struct EnforceNotMet : public std::exception {
std::exception_ptr exp_;
std::string err_str_;
......@@ -48,15 +68,29 @@ struct EnforceNotMet : public std::exception {
std::rethrow_exception(exp_);
} catch (const std::exception& exp) {
std::ostringstream sout;
sout << string::Sprintf("%s at [%s:%d]", exp.what(), f, l) << std::endl;
sout << "Call Stacks: " << std::endl;
sout << "PaddlePaddle Call Stacks: " << std::endl;
void* call_stack[TRACE_STACK_LIMIT];
int sz = backtrace(call_stack, TRACE_STACK_LIMIT);
auto line = backtrace_symbols(call_stack, sz);
for (int i = 0; i < sz; ++i) {
sout << line[i] << std::endl;
auto size = backtrace(call_stack, TRACE_STACK_LIMIT);
auto symbols = backtrace_symbols(call_stack, size);
Dl_info info;
for (int i = 0; i < size; ++i) {
if (dladdr(call_stack[i], &info)) {
auto demangled = demangle(info.dli_sname);
auto addr_offset = static_cast<char*>(call_stack[i]) -
static_cast<char*>(info.dli_saddr);
sout << string::Sprintf("%-3d %*0p %s + %zd\n", i,
2 + sizeof(void*) * 2, call_stack[i],
demangled, addr_offset);
} else {
sout << string::Sprintf("%-3d %*0p %s\n", i, 2 + sizeof(void*) * 2,
call_stack[i]);
}
}
free(line);
free(symbols);
err_str_ = sout.str();
}
}
......@@ -170,7 +204,7 @@ inline void throw_on_error(T e) {
* PADDLE_ENFORCE_EQ(a, b);
*
* will raise an expression described as follows:
* "enforce a == b failed, 1 != 2" with detailed stack infomation.
* "enforce a == b failed, 1 != 2" with detailed stack information.
*
* extra messages is also supported, for example:
* PADDLE_ENFORCE(a, b, "some simple enforce failed between %d numbers", 2)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册