未验证 提交 f649442d 编写于 作者: C Chen Weihang 提交者: GitHub

New custom operator extension mechanism (#30690)

* initial commit: simple demo

* polish copyright format

* add grap op simple demo

* adapt uncertain number of argument

* change trait marco name

* add place & dtype support for add kernel

* add dispath and infershape func

* poish code & add notes

* add dynamic_loader dep for paddle_framework

* add new custom op test dir

* polish impl details

* add unittest for new custom op

* fix failed unittest

* Costum op (#1)

* fix compile error

* wrap framework tensor with LoDTensor

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* add CustomTensor default constructor

* add size() for CustomTensor

* make size const for CustomTensor

* refactor place related api to circle the concept

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* make place const

* make Tensor copy

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* remove additional head of framework

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* add gpu test

* merge latest cwh code in

* adjust ut code of custom op

* adjust ut code of custom op

* adjust ut code of custom op

* Remove ShareData from user && Change CustomTensor to Tensor && Support more data type (#2)

* fix compile error

* wrap framework tensor with LoDTensor

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* add CustomTensor default constructor

* add size() for CustomTensor

* make size const for CustomTensor

* refactor place related api to circle the concept

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* make place const

* make Tensor copy

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* remove additional head of framework

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* add gpu test

* merge latest cwh code in

* adjust ut code of custom op

* adjust ut code of custom op

* adjust ut code of custom op

* adjust ut code of custom op

* adjust ut code of custom op

* hid share data from and to

* rename CustomTensor to Tensor

* refactor register design & add test

* change op_funtion to op_meta_info

* split op meta info into .h and .cc

* move get methods into friend class

* move OpMetaInfoHelper into framework space

* move CustomTensorUtils into framework space

* change pybind api name

* move PD C API into op meta info

* add register custom op api

* remove inference cmake change

* refactor copy to api && change Reshape to lowercase && support more dtype && add more test (#3)

* fix compile error

* wrap framework tensor with LoDTensor

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* add CustomTensor default constructor

* add size() for CustomTensor

* make size const for CustomTensor

* refactor place related api to circle the concept

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* make place const

* make Tensor copy

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* remove additional head of framework

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* add gpu test

* merge latest cwh code in

* adjust ut code of custom op

* adjust ut code of custom op

* adjust ut code of custom op

* adjust ut code of custom op

* adjust ut code of custom op

* hid share data from and to

* rename CustomTensor to Tensor

* support multi dtype

* remove lod, make reshape lowercase, add copy test and refactor copy api

* remove lod, make reshape lowercase, add copy test and refactor copy api

* remove lod, make reshape lowercase, add copy test and refactor copy api

* remove lod, make reshape lowercase, add copy test and refactor copy api

* fix copy to error

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* polish detail & error message

* polish test details

* Add cast api && Change copy related api to copy_to && add more test (#4)

* fix compile error

* wrap framework tensor with LoDTensor

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* add CustomTensor default constructor

* add size() for CustomTensor

* make size const for CustomTensor

* refactor place related api to circle the concept

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* make place const

* make Tensor copy

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* remove additional head of framework

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* add gpu test

* merge latest cwh code in

* adjust ut code of custom op

* adjust ut code of custom op

* adjust ut code of custom op

* adjust ut code of custom op

* adjust ut code of custom op

* hid share data from and to

* rename CustomTensor to Tensor

* support multi dtype

* remove lod, make reshape lowercase, add copy test and refactor copy api

* remove lod, make reshape lowercase, add copy test and refactor copy api

* remove lod, make reshape lowercase, add copy test and refactor copy api

* remove lod, make reshape lowercase, add copy test and refactor copy api

* fix copy to error

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add type cast

* add cast and make copy to api

* add cast and make copy to api

* add cast and make copy to api

* add cast and make copy to api

* merge cwh code

* merge cwh code

* merge cwh code

* merge cwh code

* merge cwh code

* add more error log

* add more error log

* polish code

* used for test

* remove test comment

* remove test comment

* fix uint8 type error

* fix lost uint8 type error

* add test for coverage

* polish details by reviewer comments

* add prefix for DISABLE_COPY_AND_ASSIGN
Co-authored-by: NJiabin Yang <360788950@qq.com>
上级 5c033271
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
// All paddle apis in C++ frontend
#include "paddle/fluid/extension/include/all.h"
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#if !defined(_MSC_VER) && __cplusplus < 199711L
#error C++11 or later compatible compiler is required to use Paddle.
#endif
#include "paddle/fluid/extension/include/dispatch.h"
#include "paddle/fluid/extension/include/dtype.h"
#include "paddle/fluid/extension/include/op_meta_info.h"
#include "paddle/fluid/extension/include/place.h"
#include "paddle/fluid/extension/include/tensor.h"
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/extension/include/dtype.h"
namespace paddle {
#define PD_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, HINT, ...) \
case enum_type: { \
using HINT = type; \
__VA_ARGS__(); \
break; \
}
#define PD_PRIVATE_CASE_TYPE(NAME, enum_type, type, ...) \
PD_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, data_t, __VA_ARGS__)
#define PD_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& dtype = TYPE; \
switch (dtype) { \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double, \
__VA_ARGS__) \
default: \
throw std::runtime_error("function not implemented for this type."); \
} \
}()
// TODD(chenweihang): implement other DISPATH macros in next PR
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
enum DataType {
FLOAT32,
FLOAT64,
BFLOAT16,
COMPLEX128,
COMPLEX64,
FLOAT16,
INT64,
INT32,
INT16,
UINT8,
INT8,
BOOL,
// TODO(JiabinYang) support more data types if needed.
};
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <unordered_map>
#include <vector>
#include <boost/any.hpp>
#include "paddle/fluid/extension/include/tensor.h"
/**
* Op Meta Info Related Define.
*
* Used to maintain operator core information.
*
*/
namespace paddle {
namespace framework {
class OpMetaInfoHelper;
} // namespace framework
using Tensor = paddle::Tensor;
#define PD_DISABLE_COPY_AND_ASSIGN(classname) \
private: \
classname(const classname&) = delete; \
classname(classname&&) = delete; \
classname& operator=(const classname&) = delete; \
classname& operator=(classname&&) = delete
///////////////// Util Define and Function ////////////////
inline std::string Grad(const std::string& var_name) {
std::string result;
result.reserve(var_name.size() + 5U);
result += var_name;
result += "@GRAD";
return result;
}
////////////////////// Kernel Function (PD_KERNEL) ////////////////////////
// Record Op kernel core function
using KernelFunc = std::vector<Tensor> (*)(std::vector<Tensor> inputs,
std::vector<boost::any> attrs);
template <typename T>
struct TypeTag {};
template <typename F, F f>
struct KernelFuncImpl;
template <typename Return, typename... Args, Return (*impl_fn)(Args...)>
struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
static Return Compute(std::vector<Tensor> inputs,
std::vector<boost::any> attrs) {
return ComputeCallHelper<Args..., TypeTag<int>>::template Compute<0, 0>(
inputs, attrs);
}
private:
template <typename... RemainingArgs>
struct ComputeCallHelper;
// for Tensor input
template <typename... Tail>
struct ComputeCallHelper<const Tensor&, Tail...> {
template <int in_idx, int attr_idx, typename... PreviousArgs>
static Return Compute(std::vector<Tensor> inputs,
std::vector<boost::any> attrs,
const PreviousArgs&... pargs) {
static_assert(attr_idx == 0,
"Input tensor should appear before attributes.");
const Tensor& arg = inputs[in_idx];
return ComputeCallHelper<Tail...>::template Compute<in_idx + 1, attr_idx>(
inputs, attrs, pargs..., arg);
}
};
// TODO(chenweihang): add support for attribute input
// int attribute input (not used now)
template <typename... Tail>
struct ComputeCallHelper<int, Tail...> {
template <int in_idx, int attr_idx, typename... PreviousArgs>
static Return Compute(std::vector<Tensor> inputs,
std::vector<boost::any> attrs,
const PreviousArgs&... pargs) {
try {
int arg = boost::any_cast<int>(attrs[attr_idx]);
return ComputeCallHelper<Tail...>::template Compute<in_idx,
attr_idx + 1>(
inputs, attrs, pargs..., arg);
} catch (boost::bad_any_cast&) {
throw std::runtime_error(
"Attribute cast error in custom operator. Expected int value.");
}
}
};
// end: base template
template <typename T>
struct ComputeCallHelper<TypeTag<T>> {
template <int in_idx, int attr_idx>
static Return Compute(std::vector<Tensor> inputs,
std::vector<boost::any> attrs, const Args&... args) {
return impl_fn(args...);
}
};
};
#define PD_KERNEL(...) \
::paddle::KernelFuncImpl<decltype(&__VA_ARGS__), &__VA_ARGS__>::Compute
/////////////// InferShape Function (PD_INFER_SHAPE) ///////////////
// Record Op infershape core function
using InferShapeFunc = std::vector<std::vector<int64_t>> (*)(
std::vector<std::vector<int64_t>> input_shapes);
template <typename F, F f>
struct InferShapeFuncImpl;
template <typename Return, typename... Args, Return (*impl_fn)(Args...)>
struct InferShapeFuncImpl<Return (*)(Args...), impl_fn> {
static Return InferShape(std::vector<std::vector<int64_t>> input_shapes) {
return InferShapeCallHelper<Args..., TypeTag<int>>::template InferShape<0>(
input_shapes);
}
private:
template <typename... RemainingArgs>
struct InferShapeCallHelper;
// only one type input: std::vector<int64_t>
template <typename... Tail>
struct InferShapeCallHelper<std::vector<int64_t>, Tail...> {
template <int in_idx, typename... PreviousArgs>
static Return InferShape(std::vector<std::vector<int64_t>> input_shapes,
const PreviousArgs&... pargs) {
std::vector<int64_t> arg = input_shapes[in_idx];
return InferShapeCallHelper<Tail...>::template InferShape<in_idx + 1>(
input_shapes, pargs..., arg);
}
};
// end: base template
template <typename T>
struct InferShapeCallHelper<TypeTag<T>> {
template <int in_idx>
static Return InferShape(std::vector<std::vector<int64_t>> input_shapes,
const Args&... args) {
return impl_fn(args...);
}
};
};
#define PD_INFER_SHAPE(...) \
::paddle::InferShapeFuncImpl<decltype(&__VA_ARGS__), &__VA_ARGS__>::InferShape
/////////////// InferDataType Function (PD_INFER_DTYPE) ///////////////
// Record Op Infer dtype core function
using InferDtypeFunc =
std::vector<DataType> (*)(std::vector<DataType> input_dtypes);
template <typename F, F f>
struct InferDtypeFuncImpl;
template <typename Return, typename... Args, Return (*impl_fn)(Args...)>
struct InferDtypeFuncImpl<Return (*)(Args...), impl_fn> {
static Return InferDtype(std::vector<DataType> input_dtypes) {
return InferDtypeCallHelper<Args..., TypeTag<int>>::template InferDtype<0>(
input_dtypes);
}
private:
template <typename... RemainingArgs>
struct InferDtypeCallHelper;
// Only one type input now: DataType
template <typename... Tail>
struct InferDtypeCallHelper<DataType, Tail...> {
template <int in_idx, typename... PreviousArgs>
static Return InferDtype(std::vector<DataType> input_dtypes,
const PreviousArgs&... pargs) {
DataType arg = input_dtypes[in_idx];
return InferDtypeCallHelper<Tail...>::template InferDtype<in_idx + 1>(
input_dtypes, pargs..., arg);
}
};
// end: base template
template <typename T>
struct InferDtypeCallHelper<TypeTag<T>> {
template <int in_idx>
static Return InferDtype(std::vector<DataType> input_dtypes,
const Args&... args) {
return impl_fn(args...);
}
};
};
#define PD_INFER_DTYPE(...) \
::paddle::InferDtypeFuncImpl<decltype(&__VA_ARGS__), &__VA_ARGS__>::InferDtype
////////////////////// Op Meta Info //////////////////////
class OpMetaInfo {
public:
explicit OpMetaInfo(const std::string& op_name) : name_(op_name) {}
OpMetaInfo& Inputs(std::vector<std::string>&& inputs);
OpMetaInfo& Outputs(std::vector<std::string>&& outputs);
OpMetaInfo& SetKernelFn(KernelFunc&& func);
OpMetaInfo& SetInferShapeFn(InferShapeFunc&& func);
OpMetaInfo& SetInferDtypeFn(InferDtypeFunc&& func);
private:
friend class framework::OpMetaInfoHelper;
// 1. desc info
std::string name_;
std::vector<std::string> inputs_;
std::vector<std::string> outputs_;
std::vector<std::string> attrs_;
// 2. func info
KernelFunc kernel_fn_;
InferShapeFunc infer_shape_fn_;
InferDtypeFunc infer_dtype_fn_;
};
//////////////// Op Meta Info Map /////////////////
class OpMetaInfoMap {
public:
// this function's impl should keep in header file.
// if move to cc file, meta info can not be added
// into map
static OpMetaInfoMap& Instance() {
static OpMetaInfoMap g_custom_op_meta_info_map;
return g_custom_op_meta_info_map;
}
std::vector<OpMetaInfo>& operator[](const std::string& name);
const std::unordered_map<std::string, std::vector<OpMetaInfo>>& GetMap()
const;
private:
OpMetaInfoMap() = default;
std::unordered_map<std::string, std::vector<OpMetaInfo>> map_;
PD_DISABLE_COPY_AND_ASSIGN(OpMetaInfoMap);
};
//////////////// Op Meta Info Builder /////////////////
class OpMetaInfoBuilder {
public:
explicit OpMetaInfoBuilder(std::string&& name);
OpMetaInfoBuilder& Inputs(std::vector<std::string>&& inputs);
OpMetaInfoBuilder& Outputs(std::vector<std::string>&& outputs);
OpMetaInfoBuilder& SetKernelFn(KernelFunc&& func);
OpMetaInfoBuilder& SetInferShapeFn(InferShapeFunc&& func);
OpMetaInfoBuilder& SetInferDtypeFn(InferDtypeFunc&& func);
OpMetaInfoBuilder& SetBackwardOp(const std::string& bwd_op_name);
private:
// Forward Op name
std::string name_;
// Point to the currently constructed op meta info
OpMetaInfo* info_ptr_;
};
/////////////////////// Op register API /////////////////////////
// For inference: compile directly with framework
// Call after PD_BUILD_OPERATOR(...)
void RegisterAllCustomOperator();
/////////////////////// Op register Macro /////////////////////////
#define PD_BUILD_OPERATOR(op_name) \
static ::paddle::OpMetaInfoBuilder __op_meta_info_##__COUNTER__##__ = \
::paddle::OpMetaInfoBuilder(op_name)
} // namespace paddle
///////////////////// C API ///////////////////
#ifdef __cplusplus
extern "C" {
#endif
// C-API to get global OpMetaInfoMap.
paddle::OpMetaInfoMap& PD_GetOpMetaInfoMap();
#ifdef __cplusplus
}
#endif
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
namespace paddle {
// TODO(yangjiabin): Add other place support in next PR
enum class PlaceType { kUNK = -1, kCPU, kGPU };
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/extension/include/dtype.h"
#include "paddle/fluid/extension/include/place.h"
namespace paddle {
namespace framework {
class CustomTensorUtils;
} // namespace framework
class Tensor {
public:
/// \brief Construct a Tensor on None Place for CustomOp.
/// Generally it's only used for user to create Tensor.
explicit Tensor(const PlaceType& place);
/// \brief Reset the shape of the tensor.
/// Generally it's only used for the input tensor.
/// Reshape must be called before calling
/// mutable_data() or copy_from_cpu()
/// \param shape The shape to set.
void reshape(const std::vector<int>& shape);
/// \brief Get the memory pointer in CPU or GPU with
/// specific data type.
/// Please Reshape the tensor first before call this.
/// It's usually used to get input data pointer.
/// \param place The place of the tensor this will
/// override the original place of current tensor.
template <typename T>
T* mutable_data(const PlaceType& place);
/// \brief Get the memory pointer in CPU or GPU with
/// specific data type. Please Reshape the tensor
/// first before call this.It's usually used to get
/// input data pointer.
template <typename T>
T* mutable_data();
/// \brief Get the memory pointer directly.
/// It's usually used to get the output data pointer.
/// \return The tensor data buffer pointer.
template <typename T>
T* data() const;
/// \brief Copy the host memory to tensor data.
/// It's usually used to set the input tensor data.
/// \param PlaceType of target place, from which
/// the tensor will copy.
template <typename T>
Tensor copy_to(const PlaceType& place);
/// \brief Return the shape of the Tensor.
std::vector<int> shape() const;
/// \brief Return the data type of the tensor.
/// It's usually used to get the output tensor data type.
/// \return The data type of the tensor.
DataType type() const;
/// \brief Get the size of current tensor.
/// Use this method to get the size of tensor
/// \return int64_t.
int64_t size() const;
/// \brief Get the place of current tensor.
/// Use this method to get the place of tensor
/// \return Place.
const PlaceType& place() const;
/// \brief Cast datatype from one to another
Tensor cast(const DataType& target_type);
private:
friend class framework::CustomTensorUtils;
mutable std::shared_ptr<void> tensor_;
mutable PlaceType place_;
};
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/extension/include/op_meta_info.h"
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/custom_operator.h"
namespace paddle {
////////////////////// Op Meta Info //////////////////////
OpMetaInfo& OpMetaInfo::Inputs(std::vector<std::string>&& inputs) {
inputs_ = std::forward<std::vector<std::string>>(inputs);
return *this;
}
OpMetaInfo& OpMetaInfo::Outputs(std::vector<std::string>&& outputs) {
outputs_ = std::forward<std::vector<std::string>>(outputs);
return *this;
}
OpMetaInfo& OpMetaInfo::SetKernelFn(KernelFunc&& func) {
kernel_fn_ = std::forward<KernelFunc>(func);
return *this;
}
OpMetaInfo& OpMetaInfo::SetInferShapeFn(InferShapeFunc&& func) {
infer_shape_fn_ = std::forward<InferShapeFunc>(func);
return *this;
}
OpMetaInfo& OpMetaInfo::SetInferDtypeFn(InferDtypeFunc&& func) {
infer_dtype_fn_ = std::forward<InferDtypeFunc>(func);
return *this;
}
//////////////// Op Meta Info Map /////////////////
std::vector<OpMetaInfo>& OpMetaInfoMap::operator[](const std::string& name) {
return map_[name];
}
const std::unordered_map<std::string, std::vector<OpMetaInfo>>&
OpMetaInfoMap::GetMap() const {
return map_;
}
//////////////// Op Meta Info Builder /////////////////
OpMetaInfoBuilder::OpMetaInfoBuilder(std::string&& name) {
name_ = std::forward<std::string>(name);
auto& info_vector = OpMetaInfoMap::Instance()[name_];
auto op_meta = OpMetaInfo(name_);
info_vector.emplace_back(std::move(op_meta));
info_ptr_ = &(info_vector.back());
}
OpMetaInfoBuilder& OpMetaInfoBuilder::Inputs(
std::vector<std::string>&& inputs) {
info_ptr_->Inputs(std::forward<std::vector<std::string>>(inputs));
return *this;
}
OpMetaInfoBuilder& OpMetaInfoBuilder::Outputs(
std::vector<std::string>&& outputs) {
info_ptr_->Outputs(std::forward<std::vector<std::string>>(outputs));
return *this;
}
OpMetaInfoBuilder& OpMetaInfoBuilder::SetKernelFn(KernelFunc&& func) {
info_ptr_->SetKernelFn(std::forward<KernelFunc>(func));
return *this;
}
OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferShapeFn(InferShapeFunc&& func) {
info_ptr_->SetInferShapeFn(std::forward<InferShapeFunc>(func));
return *this;
}
OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferDtypeFn(InferDtypeFunc&& func) {
info_ptr_->SetInferDtypeFn(std::forward<InferDtypeFunc>(func));
return *this;
}
OpMetaInfoBuilder& OpMetaInfoBuilder::SetBackwardOp(
const std::string& bwd_op_name) {
auto& info_vector = OpMetaInfoMap::Instance()[name_];
auto op_meta = OpMetaInfo(bwd_op_name);
info_vector.emplace_back(std::move(op_meta));
info_ptr_ = &(info_vector.back());
return *this;
}
/////////////////////// Op register API /////////////////////////
void RegisterAllCustomOperator() {
auto& op_meta_info_map = OpMetaInfoMap::Instance();
framework::RegisterOperatorWithMetaInfoMap(op_meta_info_map);
}
} // namespace paddle
extern "C" {
paddle::OpMetaInfoMap& PD_GetOpMetaInfoMap() {
return paddle::OpMetaInfoMap::Instance();
}
} // end extern "C"
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/extension/include/tensor.h"
#include <utility>
#include "paddle/fluid/framework/custom_tensor_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
template <typename InType, typename OutType>
struct CastDataTypeFunctor {
HOSTDEVICE inline OutType operator()(InType in) const {
return static_cast<OutType>(in);
}
};
template <typename InType>
struct CastDataType {
CastDataType(const framework::Tensor &in, framework::Tensor *out,
const platform::DeviceContext *ctx)
: in_(in), out_(out), ctx_(ctx) {}
const framework::Tensor in_;
framework::Tensor *out_;
const platform::DeviceContext *ctx_;
template <typename OutType>
void apply() {
auto *in_begin = in_.data<InType>();
auto *in_end = in_begin + in_.numel();
auto *out_begin = out_->mutable_data<OutType>(in_.place());
if (platform::is_cpu_place(in_.place())) {
platform::Transform<platform::CPUDeviceContext> trans;
auto *context = static_cast<const platform::CPUDeviceContext *>(ctx_);
trans(*context, in_begin, in_end, out_begin,
CastDataTypeFunctor<InType, OutType>());
#ifdef __NVCC__
} else if (platform::is_gpu_place(in_.place())) {
platform::Transform<platform::CUDADeviceContext> trans;
auto *context = static_cast<const platform::CUDADeviceContext *>(ctx_);
trans(*context, in_begin, in_end, out_begin,
CastDataTypeFunctor<InType, OutType>());
context->Wait();
#endif
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Place type is not supported when casting data type."));
}
}
};
template <typename T>
void GpuCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc,
int64_t ele_size) {
#ifdef PADDLE_WITH_CUDA
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
int device_num = paddle::platform::GetCurrentDeviceId();
platform::CUDAPlace gpu_place(device_num);
auto *dev_ctx =
static_cast<const platform::CUDADeviceContext *>(pool.Get(gpu_place));
if ((src_plc == PlaceType::kGPU) && (dst_plc == PlaceType::kCPU)) {
memory::Copy(platform::CPUPlace(), static_cast<void *>(dst), gpu_place, src,
ele_size, dev_ctx->stream());
} else if ((src_plc == PlaceType::kGPU) && (dst_plc == PlaceType::kGPU)) {
memory::Copy(gpu_place, static_cast<void *>(dst), gpu_place, src, ele_size,
dev_ctx->stream());
} else if ((src_plc == PlaceType::kCPU) && (dst_plc == PlaceType::kGPU)) {
memory::Copy(gpu_place, static_cast<void *>(dst), platform::CPUPlace(), src,
ele_size, dev_ctx->stream());
} else {
PADDLE_THROW(platform::errors::Unavailable(
"Only GPU related Copy can reach this func."));
}
cudaStreamSynchronize(dev_ctx->stream());
#endif
}
#define GET_CASTED_TENSOR \
if (!tensor_) { \
tensor_ = std::make_shared<framework::LoDTensor>(); \
} \
auto *tensor = static_cast<framework::LoDTensor *>(tensor_.get());
void Tensor::reshape(const std::vector<int> &shape) {
GET_CASTED_TENSOR
tensor->Resize(framework::make_ddim(shape));
}
Tensor::Tensor(const PlaceType &place)
: tensor_(std::make_shared<framework::LoDTensor>()), place_(place) {}
template <typename T>
T *Tensor::mutable_data(const PlaceType &place) {
place_ = place;
return mutable_data<T>();
}
template <typename T>
T *Tensor::mutable_data() {
GET_CASTED_TENSOR
PADDLE_ENFORCE_GT(
tensor->numel(), 0,
platform::errors::PreconditionNotMet(
"You should call Tensor::Reshape(const std::vector<int> "
"&shape)"
"function before retrieving mutable_data from input tensor."));
switch (static_cast<int>(place_)) {
case static_cast<int>(PlaceType::kCPU): {
return tensor->mutable_data<T>(platform::CPUPlace());
}
#ifdef PADDLE_WITH_CUDA
case static_cast<int>(PlaceType::kGPU): {
int device_num = platform::GetCurrentDeviceId();
VLOG(1) << "Custom Operator: mutable data cuda device id - "
<< device_num;
return tensor->mutable_data<T>(platform::CUDAPlace(device_num));
}
#endif
default:
PADDLE_THROW(platform::errors::Unavailable(
"Custom operator unsupported place id(%d)",
static_cast<int>(place_)));
}
}
template <typename T>
T *Tensor::data() const {
GET_CASTED_TENSOR;
auto *res = tensor->data<T>();
return res;
}
DataType Tensor::type() const {
GET_CASTED_TENSOR;
auto type = tensor->type();
if (type == framework::proto::VarType::FP32) {
return DataType::FLOAT32;
} else if (type == framework::proto::VarType::INT64) {
return DataType::INT64;
} else if (type == framework::proto::VarType::INT32) {
return DataType::INT32;
} else if (type == framework::proto::VarType::INT16) {
return DataType::INT16;
} else if (type == framework::proto::VarType::INT8) {
return DataType::INT8;
} else if (type == framework::proto::VarType::UINT8) {
return DataType::UINT8;
} else if (type == framework::proto::VarType::FP64) {
return DataType::FLOAT64;
} else if (type == framework::proto::VarType::BF16) {
return DataType::BFLOAT16;
} else if (type == framework::proto::VarType::FP16) {
return DataType::FLOAT16;
} else if (type == framework::proto::VarType::COMPLEX64) {
return DataType::COMPLEX64;
} else if (type == framework::proto::VarType::COMPLEX128) {
return DataType::COMPLEX128;
} else if (type == framework::proto::VarType::BOOL) {
return DataType::BOOL;
}
return DataType::FLOAT32;
}
template <typename T>
Tensor Tensor::copy_to(const PlaceType &target_place) {
GET_CASTED_TENSOR;
PADDLE_ENFORCE_GE(tensor->numel(), 0,
platform::errors::PreconditionNotMet(
"You should call Tensor::Reshape(const "
"std::vector<int> &shape)"
"function before copying data from cpu."));
size_t ele_size = tensor->numel() * sizeof(T);
auto *p_src_data = tensor->data<T>();
auto src_place = place();
Tensor target = Tensor(target_place);
target.reshape(shape());
auto *p_target_data = target.template mutable_data<T>();
if ((src_place == PlaceType::kCPU) && (target_place == PlaceType::kCPU)) {
std::memcpy(static_cast<void *>(p_target_data), p_src_data, ele_size);
} else if ((src_place == PlaceType::kGPU) &&
(target_place == PlaceType::kCPU)) {
GpuCopy<T>(p_src_data, p_target_data, src_place, target_place, ele_size);
} else if ((src_place == PlaceType::kCPU) &&
(target_place == PlaceType::kGPU)) {
GpuCopy<T>(p_src_data, p_target_data, src_place, target_place, ele_size);
} else if ((src_place == PlaceType::kGPU) &&
(target_place == PlaceType::kGPU)) {
GpuCopy<T>(p_src_data, p_target_data, src_place, target_place, ele_size);
} else {
PADDLE_THROW(platform::errors::Unavailable(
"Not supported place transform of place: %d to place: %d",
static_cast<int>(src_place), static_cast<int>(target_place)));
}
return target;
}
template Tensor Tensor::copy_to<paddle::platform::float16>(
const PlaceType &target_place);
template Tensor Tensor::copy_to<paddle::platform::bfloat16>(
const PlaceType &target_place);
template Tensor Tensor::copy_to<paddle::platform::complex64>(
const PlaceType &target_place);
template Tensor Tensor::copy_to<paddle::platform::complex128>(
const PlaceType &target_place);
template Tensor Tensor::copy_to<float>(const PlaceType &target_place);
template Tensor Tensor::copy_to<double>(const PlaceType &target_place);
template Tensor Tensor::copy_to<int64_t>(const PlaceType &target_place);
template Tensor Tensor::copy_to<int32_t>(const PlaceType &target_place);
template Tensor Tensor::copy_to<uint8_t>(const PlaceType &target_place);
template Tensor Tensor::copy_to<int8_t>(const PlaceType &target_place);
template Tensor Tensor::copy_to<int16_t>(const PlaceType &target_place);
template Tensor Tensor::copy_to<bool>(const PlaceType &target_place);
template float *Tensor::data<float>() const;
template double *Tensor::data<double>() const;
template int64_t *Tensor::data<int64_t>() const;
template int32_t *Tensor::data<int32_t>() const;
template uint8_t *Tensor::data<uint8_t>() const;
template int8_t *Tensor::data<int8_t>() const;
template paddle::platform::float16 *Tensor::data<paddle::platform::float16>()
const;
template paddle::platform::bfloat16 *Tensor::data<paddle::platform::bfloat16>()
const;
template paddle::platform::complex128 *
Tensor::data<paddle::platform::complex128>() const;
template paddle::platform::complex64 *
Tensor::data<paddle::platform::complex64>() const;
template int16_t *Tensor::data<int16_t>() const;
template bool *Tensor::data<bool>() const;
template float *Tensor::mutable_data<float>();
template double *Tensor::mutable_data<double>();
template int64_t *Tensor::mutable_data<int64_t>();
template int32_t *Tensor::mutable_data<int32_t>();
template uint8_t *Tensor::mutable_data<uint8_t>();
template int8_t *Tensor::mutable_data<int8_t>();
template paddle::platform::float16 *
Tensor::mutable_data<paddle::platform::float16>();
template paddle::platform::bfloat16 *
Tensor::mutable_data<paddle::platform::bfloat16>();
template paddle::platform::complex128 *
Tensor::mutable_data<paddle::platform::complex128>();
template paddle::platform::complex64 *
Tensor::mutable_data<paddle::platform::complex64>();
template int16_t *Tensor::mutable_data<int16_t>();
template bool *Tensor::mutable_data<bool>();
template float *Tensor::mutable_data<float>(const PlaceType &place);
template double *Tensor::mutable_data<double>(const PlaceType &place);
template int64_t *Tensor::mutable_data<int64_t>(const PlaceType &place);
template int32_t *Tensor::mutable_data<int32_t>(const PlaceType &place);
template uint8_t *Tensor::mutable_data<uint8_t>(const PlaceType &place);
template int8_t *Tensor::mutable_data<int8_t>(const PlaceType &place);
template paddle::platform::float16 *
Tensor::mutable_data<paddle::platform::float16>(const PlaceType &place);
template paddle::platform::bfloat16 *
Tensor::mutable_data<paddle::platform::bfloat16>(const PlaceType &place);
template paddle::platform::complex128 *
Tensor::mutable_data<paddle::platform::complex128>(const PlaceType &place);
template paddle::platform::complex64 *
Tensor::mutable_data<paddle::platform::complex64>(const PlaceType &place);
template int16_t *Tensor::mutable_data<int16_t>(const PlaceType &place);
template bool *Tensor::mutable_data<bool>(const PlaceType &place);
std::vector<int> Tensor::shape() const {
GET_CASTED_TENSOR
return framework::vectorize<int>(tensor->dims());
}
const PlaceType &Tensor::place() const {
GET_CASTED_TENSOR;
if (platform::is_cpu_place(tensor->place())) {
place_ = PlaceType::kCPU;
} else if (platform::is_gpu_place(tensor->place())) {
place_ = PlaceType::kGPU;
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Current Tensor hold unsupported Place Type, Please Init it"
"using Tensor::mutable_data<T>(PaddlePlace) which T is"
"either Place::kCPU or Place::kGPU"));
}
return place_;
}
Tensor Tensor::cast(const DataType &target_type) {
GET_CASTED_TENSOR;
Tensor rlt = Tensor(place());
rlt.reshape(this->shape());
auto rlt_tensor_ = static_cast<framework::LoDTensor *>(rlt.tensor_.get());
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto ctx = pool.Get(tensor->place());
auto src_type = tensor->type();
auto dst_type =
framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(target_type);
switch (src_type) {
case framework::proto::VarType::FP16:
framework::VisitDataType(
dst_type, CastDataType<platform::float16>(*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::BF16:
framework::VisitDataType(dst_type, CastDataType<platform::bfloat16>(
*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::FP32:
framework::VisitDataType(dst_type,
CastDataType<float>(*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::FP64:
framework::VisitDataType(dst_type,
CastDataType<double>(*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::INT32:
framework::VisitDataType(dst_type,
CastDataType<int>(*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::INT64:
framework::VisitDataType(
dst_type, CastDataType<int64_t>(*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::BOOL:
framework::VisitDataType(dst_type,
CastDataType<bool>(*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::INT16:
framework::VisitDataType(
dst_type, CastDataType<int16_t>(*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::UINT8:
framework::VisitDataType(
dst_type, CastDataType<uint8_t>(*tensor, rlt_tensor_, ctx));
break;
// TODO(JiabinYang): Support Complex later
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Data type (%s) is not supported when casting data type.",
framework::DataTypeToString(src_type)));
}
return rlt;
}
int64_t Tensor::size() const {
GET_CASTED_TENSOR;
return tensor->numel();
}
namespace framework {
void CustomTensorUtils::ShareDataTo(const paddle::Tensor &src, void *dst) {
static_cast<framework::LoDTensor *>(dst)->ShareDataWith(
*static_cast<framework::LoDTensor *>(src.tensor_.get()));
}
void CustomTensorUtils::ShareDataFrom(const void *src,
const paddle::Tensor &dst) {
if (!dst.tensor_) {
dst.tensor_ = std::make_shared<framework::LoDTensor>();
}
auto *tensor = static_cast<framework::LoDTensor *>(dst.tensor_.get());
tensor->ShareDataWith(*static_cast<const framework::LoDTensor *>(src));
}
} // namespace framework
} // namespace paddle
......@@ -320,11 +320,17 @@ message(STATUS "branch: ${PADDLE_BRANCH}")
configure_file(commit.h.in commit.h)
set(FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_proto layer)
cc_library(custom_tensor SRCS ../extension/src/tensor.cc DEPS lod_tensor)
cc_library(op_meta_info SRCS ../extension/src/op_meta_info.cc DEPS custom_tensor)
cc_library(custom_operator SRCS custom_operator.cc DEPS operator op_registry device_context dynamic_loader custom_tensor op_meta_info)
cc_test(custom_tensor_test SRCS custom_tensor_test.cc DEPS custom_tensor glog)
set(FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_proto layer dynamic_loader custom_operator)
cc_library(paddle_framework DEPS ${FLUID_FRAMEWORK_MODULES})
cc_library(paddle_framework_shared
SHARED SRCS executor.cc operator.cc
SHARED SRCS executor.cc operator.cc custom_operator.cc ../extension/src/tensor.cc
../extension/src/op_meta_info.cc
${CMAKE_CURRENT_SOURCE_DIR}/c/c_api.cc
${CMAKE_SOURCE_DIR}/paddle/fluid/imperative/layer.cc
DEPS ${FLUID_FRAMEWORK_MODULES})
......
此差异已折叠。
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/extension/include/op_meta_info.h"
namespace paddle {
namespace framework {
// Load custom op api: register op after user compiled
void LoadOpMetaInfoAndRegisterOp(const std::string& dso_name);
// Register custom op api: register op directly
void RegisterOperatorWithMetaInfoMap(
const paddle::OpMetaInfoMap& op_meta_info_map);
} // namespace framework
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/extension/include/all.h"
#include "paddle/fluid/framework/custom_tensor_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"
template <typename T>
paddle::Tensor InitCPUTensorForTest() {
std::vector<int> tensor_shape{5, 5};
auto t1 = paddle::Tensor(paddle::PlaceType::kCPU);
t1.reshape(tensor_shape);
auto* p_data_ptr = t1.mutable_data<T>(paddle::PlaceType::kCPU);
for (int64_t i = 0; i < t1.size(); i++) {
p_data_ptr[i] = 5;
}
return t1;
}
template <typename T>
void TestCopyTensor() {
auto t1 = InitCPUTensorForTest<T>();
auto t1_cpu_cp = t1.template copy_to<T>(paddle::PlaceType::kCPU);
CHECK((paddle::PlaceType::kCPU == t1_cpu_cp.place()));
for (int64_t i = 0; i < t1.size(); i++) {
CHECK_EQ(t1_cpu_cp.template data<T>()[i], 5);
}
#ifdef PADDLE_WITH_CUDA
VLOG(2) << "Do GPU copy test";
auto t1_gpu_cp = t1_cpu_cp.template copy_to<T>(paddle::PlaceType::kGPU);
CHECK((paddle::PlaceType::kGPU == t1_gpu_cp.place()));
auto t1_gpu_cp_cp = t1_gpu_cp.template copy_to<T>(paddle::PlaceType::kGPU);
CHECK((paddle::PlaceType::kGPU == t1_gpu_cp_cp.place()));
auto t1_gpu_cp_cp_cpu =
t1_gpu_cp.template copy_to<T>(paddle::PlaceType::kCPU);
CHECK((paddle::PlaceType::kCPU == t1_gpu_cp_cp_cpu.place()));
for (int64_t i = 0; i < t1.size(); i++) {
CHECK_EQ(t1_gpu_cp_cp_cpu.template data<T>()[i], 5);
}
#endif
}
void TestAPIPlace() {
std::vector<int> tensor_shape = {5, 5};
#ifdef PADDLE_WITH_CUDA
auto t1 = paddle::Tensor(paddle::PlaceType::kGPU);
t1.reshape(tensor_shape);
t1.mutable_data<float>();
CHECK((paddle::PlaceType::kGPU == t1.place()));
#endif
auto t2 = paddle::Tensor(paddle::PlaceType::kCPU);
t2.reshape(tensor_shape);
t2.mutable_data<float>();
CHECK((paddle::PlaceType::kCPU == t2.place()));
}
void TestAPISizeAndShape() {
std::vector<int> tensor_shape = {5, 5};
auto t1 = paddle::Tensor(paddle::PlaceType::kCPU);
t1.reshape(tensor_shape);
CHECK_EQ(t1.size(), 25);
CHECK(t1.shape() == tensor_shape);
}
template <typename T>
paddle::DataType TestDtype() {
std::vector<int> tensor_shape = {5, 5};
auto t1 = paddle::Tensor(paddle::PlaceType::kCPU);
t1.reshape(tensor_shape);
t1.template mutable_data<T>();
return t1.type();
}
template <typename T>
void TestCast(paddle::DataType data_type) {
std::vector<int> tensor_shape = {5, 5};
auto t1 = paddle::Tensor(paddle::PlaceType::kCPU);
t1.reshape(tensor_shape);
t1.template mutable_data<T>();
auto t2 = t1.cast(data_type);
CHECK_EQ(t2.type(), data_type);
}
void GroupTestCopy() {
VLOG(2) << "Float cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<float>();
VLOG(2) << "Double cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<double>();
// TODO(JiabinYang): Support these test later
// VLOG(2) << "Fp16 cpu-cpu-gpu-gpu-cpu";
// TestCopyTensor<paddle::platform::float16>();
// VLOG(2) << "BF16 cpu-cpu-gpu-gpu-cpu";
// TestCopyTensor<paddle::platform::bfloat16>();
// VLOG(2) << "complex128 cpu-cpu-gpu-gpu-cpu";
// TestCopyTensor<paddle::platform::complex128>();
// VLOG(2) << "complex64 cpu-cpu-gpu-gpu-cpu";
// TestCopyTensor<paddle::platform::complex64>();
// VLOG(2) << "int cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<int>();
VLOG(2) << "int64 cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<int64_t>();
VLOG(2) << "int16 cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<int16_t>();
VLOG(2) << "int8 cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<int8_t>();
VLOG(2) << "uint8 cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<uint8_t>();
}
void GroupTestCast() {
VLOG(2) << "int cast";
TestCast<int>(paddle::DataType::FLOAT32);
VLOG(2) << "int32 cast";
TestCast<int32_t>(paddle::DataType::FLOAT32);
VLOG(2) << "int64 cast";
TestCast<int64_t>(paddle::DataType::FLOAT32);
VLOG(2) << "double cast";
TestCast<double>(paddle::DataType::FLOAT32);
VLOG(2) << "bfloat16 cast";
TestCast<paddle::platform::bfloat16>(paddle::DataType::FLOAT32);
VLOG(2) << "float16 cast";
TestCast<paddle::platform::float16>(paddle::DataType::FLOAT32);
VLOG(2) << "bool cast";
TestCast<bool>(paddle::DataType::FLOAT32);
VLOG(2) << "uint8 cast";
TestCast<uint8_t>(paddle::DataType::FLOAT32);
VLOG(2) << "float cast";
TestCast<float>(paddle::DataType::FLOAT32);
}
void GroupTestDtype() {
CHECK(TestDtype<float>() == paddle::DataType::FLOAT32);
CHECK(TestDtype<double>() == paddle::DataType::FLOAT64);
CHECK(TestDtype<paddle::platform::float16>() == paddle::DataType::FLOAT16);
CHECK(TestDtype<paddle::platform::bfloat16>() == paddle::DataType::BFLOAT16);
CHECK(TestDtype<paddle::platform::complex128>() ==
paddle::DataType::COMPLEX128);
CHECK(TestDtype<paddle::platform::complex64>() ==
paddle::DataType::COMPLEX64);
CHECK(TestDtype<int>() == paddle::DataType::INT32);
CHECK(TestDtype<int64_t>() == paddle::DataType::INT64);
CHECK(TestDtype<int16_t>() == paddle::DataType::INT16);
CHECK(TestDtype<int8_t>() == paddle::DataType::INT8);
CHECK(TestDtype<uint8_t>() == paddle::DataType::UINT8);
}
void GroupTestDtypeConvert() {
// enum -> proto
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::COMPLEX128) ==
paddle::framework::proto::VarType::COMPLEX128);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::COMPLEX64) ==
paddle::framework::proto::VarType::COMPLEX64);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::FLOAT64) ==
paddle::framework::proto::VarType::FP64);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::FLOAT32) ==
paddle::framework::proto::VarType::FP32);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::FLOAT16) ==
paddle::framework::proto::VarType::FP16);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::BFLOAT16) ==
paddle::framework::proto::VarType::BF16);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::UINT8) ==
paddle::framework::proto::VarType::UINT8);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::INT8) == paddle::framework::proto::VarType::INT8);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::INT32) ==
paddle::framework::proto::VarType::INT32);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::INT64) ==
paddle::framework::proto::VarType::INT64);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::INT16) ==
paddle::framework::proto::VarType::INT16);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::BOOL) == paddle::framework::proto::VarType::BOOL);
// proto -> enum
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::COMPLEX128) ==
paddle::DataType::COMPLEX128);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::COMPLEX64) ==
paddle::DataType::COMPLEX64);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::FP64) ==
paddle::DataType::FLOAT64);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::FP32) ==
paddle::DataType::FLOAT32);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::FP16) ==
paddle::DataType::FLOAT16);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::BF16) ==
paddle::DataType::BFLOAT16);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::INT64) ==
paddle::DataType::INT64);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::INT32) ==
paddle::DataType::INT32);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::INT8) == paddle::DataType::INT8);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::UINT8) ==
paddle::DataType::UINT8);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::INT16) ==
paddle::DataType::INT16);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::BOOL) == paddle::DataType::BOOL);
}
TEST(CustomTensor, copyTest) {
VLOG(2) << "TestCopy";
GroupTestCopy();
VLOG(2) << "TestDtype";
GroupTestDtype();
VLOG(2) << "TestShape";
TestAPISizeAndShape();
VLOG(2) << "TestPlace";
TestAPIPlace();
VLOG(2) << "TestCast";
GroupTestCast();
VLOG(2) << "TestDtypeConvert";
GroupTestDtypeConvert();
}
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include "paddle/fluid/extension/include/tensor.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace framework {
class CustomTensorUtils {
public:
/// \brief Share data TO another tensor.
/// Use this to pass tensor from op to op
/// \return void.
static void ShareDataTo(const paddle::Tensor& src, void* dst);
/// \brief Share data FROM another tensor.
/// Use this to pass tensor from op to op
/// \return void.
static void ShareDataFrom(const void* src, const Tensor& dst);
static framework::proto::VarType::Type ConvertEnumDTypeToInnerDType(
const paddle::DataType& dtype) {
switch (dtype) {
case paddle::DataType::COMPLEX128:
return framework::proto::VarType::COMPLEX128;
case paddle::DataType::COMPLEX64:
return framework::proto::VarType::COMPLEX64;
case paddle::DataType::FLOAT64:
return framework::proto::VarType::FP64;
case paddle::DataType::FLOAT32:
return framework::proto::VarType::FP32;
case paddle::DataType::FLOAT16:
return framework::proto::VarType::FP16;
case paddle::DataType::BFLOAT16:
return framework::proto::VarType::BF16;
case paddle::DataType::UINT8:
return framework::proto::VarType::UINT8;
case paddle::DataType::INT8:
return framework::proto::VarType::INT8;
case paddle::DataType::INT32:
return framework::proto::VarType::INT32;
case paddle::DataType::INT64:
return framework::proto::VarType::INT64;
case paddle::DataType::INT16:
return framework::proto::VarType::INT16;
case paddle::DataType::BOOL:
return framework::proto::VarType::BOOL;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported data type code(%d) when casting enum data type into "
"paddle data type.",
static_cast<int>(dtype)));
}
}
static paddle::DataType ConvertInnerDTypeToEnumDType(
const framework::proto::VarType::Type& dtype) {
switch (dtype) {
case framework::proto::VarType::COMPLEX128:
return paddle::DataType::COMPLEX128;
case framework::proto::VarType::COMPLEX64:
return paddle::DataType::COMPLEX64;
case framework::proto::VarType::FP64:
return paddle::DataType::FLOAT64;
case framework::proto::VarType::FP32:
return paddle::DataType::FLOAT32;
case framework::proto::VarType::FP16:
return paddle::DataType::FLOAT16;
case framework::proto::VarType::BF16:
return paddle::DataType::BFLOAT16;
case framework::proto::VarType::INT64:
return paddle::DataType::INT64;
case framework::proto::VarType::INT32:
return paddle::DataType::INT32;
case framework::proto::VarType::INT8:
return paddle::DataType::INT8;
case framework::proto::VarType::UINT8:
return paddle::DataType::UINT8;
case framework::proto::VarType::INT16:
return paddle::DataType::INT16;
case framework::proto::VarType::BOOL:
return paddle::DataType::BOOL;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported data type `%s` when casting paddle data type into "
"enum data type.",
DataTypeToString(dtype)));
}
}
// PaddlePlace <-> platform::Place
static platform::Place ConvertEnumPlaceToInnerPlace(const PlaceType& pc) {
if (pc == PlaceType::kCPU) {
return platform::Place(platform::CPUPlace());
} else if (pc == PlaceType::kGPU) {
#ifdef PADDLE_WITH_CUDA
return platform::Place(
platform::CUDAPlace(platform::GetCurrentDeviceId()));
#endif
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported place type code(%d) when "
"casting enum place to paddle place.",
static_cast<int>(pc)));
}
return platform::Place();
}
static PlaceType ConvertInnerPlaceToEnumPlace(const platform::Place& pc) {
if (platform::is_cpu_place(pc)) {
return PlaceType::kCPU;
} else if (platform::is_gpu_place(pc)) {
#ifdef PADDLE_WITH_CUDA
return PlaceType::kGPU;
#endif
} else {
PADDLE_THROW(
platform::errors::Unimplemented("Unsupported place type `%s` when "
"casting paddle place to enum place.",
pc));
}
return PlaceType::kUNK;
}
};
} // namespace framework
} // namespace paddle
......@@ -87,6 +87,10 @@ std::string DataTypeToString(const proto::VarType::Type type) {
if (it != gDataTypeMap().proto_to_str_.end()) {
return it->second;
}
// deal with RAW type
if (type == proto::VarType::RAW) {
return "RAW(runtime decided type)";
}
PADDLE_THROW(platform::errors::Unimplemented(
"Not support proto::VarType::Type(%d) as tensor type.",
static_cast<int>(type)));
......
......@@ -97,10 +97,10 @@ void TransDataType(const OpKernelType& kernel_type_for_var,
framework::VisitDataType(dst_type, CastDataType<bool>(in, out, ctx));
break;
case proto::VarType::INT16:
framework::VisitDataType(dst_type, CastDataType<bool>(in, out, ctx));
framework::VisitDataType(dst_type, CastDataType<int16_t>(in, out, ctx));
break;
case proto::VarType::UINT8:
framework::VisitDataType(dst_type, CastDataType<bool>(in, out, ctx));
framework::VisitDataType(dst_type, CastDataType<uint8_t>(in, out, ctx));
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/extension/include/op_meta_info.h"
namespace paddle {
namespace framework {
class OpMetaInfoHelper {
public:
static const std::string& GetOpName(const paddle::OpMetaInfo& info) {
return info.name_;
}
static const std::vector<std::string>& GetInputs(
const paddle::OpMetaInfo& info) {
return info.inputs_;
}
static const std::vector<std::string>& GetOutputs(
const paddle::OpMetaInfo& info) {
return info.outputs_;
}
static const std::vector<std::string>& GetAttrs(
const paddle::OpMetaInfo& info) {
return info.attrs_;
}
static const KernelFunc& GetKernelFn(const paddle::OpMetaInfo& info) {
return info.kernel_fn_;
}
static const InferShapeFunc& GetInferShapeFn(const paddle::OpMetaInfo& info) {
return info.infer_shape_fn_;
}
static const InferDtypeFunc& GetInferDtypeFn(const paddle::OpMetaInfo& info) {
return info.infer_dtype_fn_;
}
};
} // namespace framework
} // namespace paddle
set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapper prune
feed_fetch_method pass_builder parallel_executor profiler layer tracer engine scope_pool
analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context
gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry ps_gpu_wrapper)
gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry ps_gpu_wrapper custom_operator)
if (WITH_GPU)
set(PYBIND_DEPS ${PYBIND_DEPS} dynload_cuda)
......
......@@ -24,6 +24,7 @@ limitations under the License. */
#include <utility>
#include <vector>
#include "paddle/fluid/framework/custom_operator.h"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
......@@ -397,7 +398,7 @@ PYBIND11_MODULE(core_noavx, m) {
PyCapsule_GetPointer(dltensor->ptr(), "dltensor"));
PyCapsule_SetName(dltensor->ptr(), "used_dltensor");
DLTensor dl = dmt->dl_tensor;
Tensor tensor;
framework::Tensor tensor;
if (dl.ctx.device_type == kDLCPU) {
paddle::framework::TensorFromDLPack(dl, &tensor);
......@@ -535,77 +536,80 @@ PYBIND11_MODULE(core_noavx, m) {
BindImperative(&m);
py::class_<Tensor>(m, "Tensor", py::buffer_protocol())
.def("__array__", [](Tensor &self) { return TensorToPyArray(self); })
py::class_<framework::Tensor>(m, "Tensor", py::buffer_protocol())
.def("__array__",
[](framework::Tensor &self) { return TensorToPyArray(self); })
.def("_is_initialized",
[](const Tensor &self) { return self.IsInitialized(); })
[](const framework::Tensor &self) { return self.IsInitialized(); })
.def("_get_dims",
[](const Tensor &self) { return vectorize(self.dims()); })
[](const framework::Tensor &self) { return vectorize(self.dims()); })
.def("_set_dims",
[](Tensor &self, const std::vector<int64_t> &dim) {
[](framework::Tensor &self, const std::vector<int64_t> &dim) {
self.Resize(make_ddim(dim));
})
.def("_set_layout",
[](Tensor &self, const std::string &layout) {
[](framework::Tensor &self, const std::string &layout) {
self.set_layout(StringToDataLayout(layout));
})
.def("_alloc_float",
[](Tensor &self, paddle::platform::CUDAPlace &place) {
[](framework::Tensor &self, paddle::platform::CUDAPlace &place) {
self.mutable_data<float>(place);
})
.def("_alloc_float",
[](Tensor &self, paddle::platform::XPUPlace &place) {
[](framework::Tensor &self, paddle::platform::XPUPlace &place) {
self.mutable_data<float>(place);
})
.def("_alloc_float",
[](Tensor &self, paddle::platform::CPUPlace &place) {
[](framework::Tensor &self, paddle::platform::CPUPlace &place) {
self.mutable_data<float>(place);
})
.def("_alloc_double",
[](Tensor &self, paddle::platform::CPUPlace &place) {
[](framework::Tensor &self, paddle::platform::CPUPlace &place) {
self.mutable_data<double>(place);
})
.def("_alloc_int",
[](Tensor &self, paddle::platform::CPUPlace &place) {
[](framework::Tensor &self, paddle::platform::CPUPlace &place) {
self.mutable_data<int>(place);
})
.def("_alloc_int",
[](Tensor &self, paddle::platform::XPUPlace &place) {
[](framework::Tensor &self, paddle::platform::XPUPlace &place) {
self.mutable_data<int>(place);
})
.def("_alloc_int",
[](Tensor &self, paddle::platform::CUDAPlace &place) {
[](framework::Tensor &self, paddle::platform::CUDAPlace &place) {
self.mutable_data<int>(place);
})
.def("_alloc_int",
[](Tensor &self, paddle::platform::CUDAPinnedPlace &place) {
[](framework::Tensor &self,
paddle::platform::CUDAPinnedPlace &place) {
self.mutable_data<int>(place);
})
.def("_alloc_float",
[](Tensor &self, paddle::platform::CUDAPinnedPlace &place) {
[](framework::Tensor &self,
paddle::platform::CUDAPinnedPlace &place) {
self.mutable_data<float>(place);
})
.def("_mutable_data",
[](Tensor &self, paddle::platform::CPUPlace &place,
[](framework::Tensor &self, paddle::platform::CPUPlace &place,
paddle::framework::proto::VarType::Type type) {
return reinterpret_cast<uintptr_t>(self.mutable_data(place, type));
})
.def("_mutable_data",
[](Tensor &self, paddle::platform::XPUPlace &place,
[](framework::Tensor &self, paddle::platform::XPUPlace &place,
paddle::framework::proto::VarType::Type type) {
return reinterpret_cast<uintptr_t>(self.mutable_data(place, type));
})
.def("_mutable_data",
[](Tensor &self, paddle::platform::CUDAPlace &place,
[](framework::Tensor &self, paddle::platform::CUDAPlace &place,
paddle::framework::proto::VarType::Type type) {
return reinterpret_cast<uintptr_t>(self.mutable_data(place, type));
})
.def("_mutable_data",
[](Tensor &self, paddle::platform::CUDAPinnedPlace &place,
[](framework::Tensor &self, paddle::platform::CUDAPinnedPlace &place,
paddle::framework::proto::VarType::Type type) {
return reinterpret_cast<uintptr_t>(self.mutable_data(place, type));
})
.def("_clear", &Tensor::clear)
.def("_clear", &framework::Tensor::clear)
.def("set", SetTensorFromPyArray<paddle::platform::CPUPlace>,
py::arg("array"), py::arg("place"), py::arg("zero_copy") = false)
.def("set", SetTensorFromPyArray<paddle::platform::XPUPlace>,
......@@ -637,7 +641,9 @@ PYBIND11_MODULE(core_noavx, m) {
t.set(np.ndarray([5, 30]), fluid.CPUPlace())
)DOC")
.def("shape", [](Tensor &self) { return vectorize(self.dims()); }, R"DOC(
.def("shape",
[](framework::Tensor &self) { return vectorize(self.dims()); },
R"DOC(
Return the shape of LoDTensor.
Returns:
......@@ -655,7 +661,7 @@ PYBIND11_MODULE(core_noavx, m) {
print(t.shape()) # [5, 30]
)DOC")
.def("_to_dlpack",
[](Tensor &self) {
[](framework::Tensor &self) {
DLPackTensor dlpack_tensor(self, 1);
DLManagedTensor *dmt =
dlpack_tensor.ToCudfCompatibleDLManagedTensor();
......@@ -680,20 +686,22 @@ PYBIND11_MODULE(core_noavx, m) {
.def("_get_float_element", TensorGetElement<float>)
.def("_set_double_element", TensorSetElement<double>)
.def("_get_double_element", TensorGetElement<double>)
.def("_place", [](Tensor &self) { return self.place(); })
.def("_dtype", [](Tensor &self) { return self.type(); })
.def("_place", [](framework::Tensor &self) { return self.place(); })
.def("_dtype", [](framework::Tensor &self) { return self.type(); })
.def("_layout",
[](Tensor &self) { return DataLayoutToString(self.layout()); })
.def("_share_data_with", &Tensor::ShareDataWith)
[](framework::Tensor &self) {
return DataLayoutToString(self.layout());
})
.def("_share_data_with", &framework::Tensor::ShareDataWith)
.def("__getitem__", PySliceTensor, py::return_value_policy::reference)
.def("__str__", [](const Tensor &self) {
.def("__str__", [](const framework::Tensor &self) {
std::stringstream ostr;
ostr << self;
return ostr.str();
});
// TODO(cql): add reference: en_user_guide_lod_tensor
py::class_<LoDTensor, Tensor>(m, "LoDTensor", R"DOC(
py::class_<LoDTensor, framework::Tensor>(m, "LoDTensor", R"DOC(
LoDTensor is a Tensor with optional LoD (Level of Details) information,
it can be used for variable-length sequences,
see :ref:`user_guide_lod_tensor` for details.
......@@ -777,7 +785,8 @@ PYBIND11_MODULE(core_noavx, m) {
t = fluid.LoDTensor()
)DOC")
.def("__array__", [](Tensor &self) { return TensorToPyArray(self); })
.def("__array__",
[](framework::Tensor &self) { return TensorToPyArray(self); })
.def("__init__",
[](LoDTensor &instance, const std::vector<std::vector<size_t>>
&recursive_sequence_lengths) {
......@@ -1735,6 +1744,8 @@ All parameter, weight, gradient are variables in Paddle.
m.def("init_gflags", framework::InitGflags);
m.def("init_glog", framework::InitGLOG);
m.def("load_op_library", framework::LoadOpLib);
m.def("load_op_meta_info_and_register_op",
framework::LoadOpMetaInfoAndRegisterOp);
m.def("init_devices", []() { framework::InitDevices(); });
m.def("is_compiled_with_cuda", IsCompiledWithCUDA);
......
......@@ -30,3 +30,6 @@ endforeach()
set_tests_properties(test_custom_op_with_setup PROPERTIES TIMEOUT 180)
set_tests_properties(test_jit_load PROPERTIES TIMEOUT 180)
set_tests_properties(test_setup_install PROPERTIES TIMEOUT 180)
set_tests_properties(test_simple_custom_op_setup PROPERTIES TIMEOUT 250)
set_tests_properties(test_simple_custom_op_jit PROPERTIES TIMEOUT 180)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <vector>
#include "paddle/extension.h"
template <typename data_t>
void relu_cpu_forward_kernel(const data_t* x_data,
data_t* out_data,
int64_t x_numel) {
for (int i = 0; i < x_numel; ++i) {
out_data[i] = std::max(static_cast<data_t>(0.), x_data[i]);
}
}
template <typename data_t>
void relu_cpu_backward_kernel(const data_t* grad_out_data,
const data_t* out_data,
data_t* grad_x_data,
int64_t out_numel) {
for (int i = 0; i < out_numel; ++i) {
grad_x_data[i] =
grad_out_data[i] * (out_data[i] > static_cast<data_t>(0) ? 1. : 0.);
}
}
std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
auto out = paddle::Tensor(paddle::PlaceType::kCPU);
out.reshape(x.shape());
PD_DISPATCH_FLOATING_TYPES(
x.type(), "relu_cpu_forward", ([&] {
relu_cpu_forward_kernel<data_t>(
x.data<data_t>(), out.mutable_data<data_t>(x.place()), x.size());
}));
return {out};
}
std::vector<paddle::Tensor> relu_cpu_backward(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out) {
auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU);
grad_x.reshape(x.shape());
PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] {
relu_cpu_backward_kernel<data_t>(
grad_out.data<data_t>(),
out.data<data_t>(),
grad_x.mutable_data<data_t>(x.place()),
out.size());
}));
return {grad_x};
}
std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x);
std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out);
std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x) {
// TODO(chenweihang): Check Input
if (x.place() == paddle::PlaceType::kCPU) {
return relu_cpu_forward(x);
} else if (x.place() == paddle::PlaceType::kGPU) {
return relu_cuda_forward(x);
} else {
throw std::runtime_error("Not implemented.");
}
}
std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out) {
// TODO(chenweihang): Check Input
if (x.place() == paddle::PlaceType::kCPU) {
return relu_cpu_backward(x, out, grad_out);
} else if (x.place() == paddle::PlaceType::kGPU) {
return relu_cuda_backward(x, out, grad_out);
} else {
throw std::runtime_error("Not implemented.");
}
}
std::vector<std::vector<int64_t>> ReluInferShape(std::vector<int64_t> x_shape) {
return {x_shape};
}
std::vector<paddle::DataType> ReluInferDType(paddle::DataType x_dtype) {
return {x_dtype};
}
PD_BUILD_OPERATOR("relu2")
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(ReluForward))
.SetInferShapeFn(PD_INFER_SHAPE(ReluInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(ReluInferDType))
.SetBackwardOp("relu2_grad")
.Inputs({"X", "Out", paddle::Grad("Out")})
.Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(ReluBackward));
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
template <typename data_t>
__global__ void relu_cuda_forward_kernel(const data_t* x,
data_t* y,
const int num) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
y[i] = max(x[i], static_cast<data_t>(0.));
}
}
template <typename data_t>
__global__ void relu_cuda_backward_kernel(const data_t* dy,
const data_t* y,
data_t* dx,
const int num) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
dx[i] = dy[i] * (y[i] > 0 ? 1. : 0.);
}
}
std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x) {
auto out = paddle::Tensor(paddle::PlaceType::kGPU);
out.reshape(x.shape());
int numel = x.size();
int block = 512;
int grid = (numel + block - 1) / block;
PD_DISPATCH_FLOATING_TYPES(
x.type(), "relu_cuda_forward_kernel", ([&] {
relu_cuda_forward_kernel<data_t><<<grid, block>>>(
x.data<data_t>(), out.mutable_data<data_t>(x.place()), numel);
}));
return {out};
}
std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out) {
auto grad_x = paddle::Tensor(paddle::PlaceType::kGPU);
grad_x.reshape(x.shape());
int numel = out.size();
int block = 512;
int grid = (numel + block - 1) / block;
PD_DISPATCH_FLOATING_TYPES(
out.type(), "relu_cuda_backward_kernel", ([&] {
relu_cuda_backward_kernel<data_t><<<grid, block>>>(
grad_out.data<data_t>(),
out.data<data_t>(),
grad_x.mutable_data<data_t>(x.place()),
numel);
}));
return {grad_x};
}
......@@ -15,6 +15,10 @@ import os
from utils import paddle_includes, extra_compile_args
from paddle.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension, setup
from paddle.utils.cpp_extension.extension_utils import use_new_custom_op_load_method
# switch to old custom op method
use_new_custom_op_load_method(False)
file_dir = os.path.dirname(os.path.abspath(__file__))
......
......@@ -15,6 +15,10 @@ import os
from utils import paddle_includes, extra_compile_args
from paddle.utils.cpp_extension import CUDAExtension, setup
from paddle.utils.cpp_extension.extension_utils import use_new_custom_op_load_method
# switch to old custom op method
use_new_custom_op_load_method(False)
setup(
name='custom_relu2',
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from utils import paddle_includes, extra_compile_args
from paddle.utils.cpp_extension import CUDAExtension, setup
setup(
name='simple_setup_relu2',
ext_modules=[
CUDAExtension(
name='simple_setup_relu2',
sources=['relu_op_simple.cc', 'relu_op_simple.cu'],
include_dirs=paddle_includes,
extra_compile_args=extra_compile_args)
])
......@@ -16,6 +16,10 @@ import os
import unittest
from test_custom_op import CustomOpTest, load_so
from paddle.utils.cpp_extension.extension_utils import run_cmd
from paddle.utils.cpp_extension.extension_utils import use_new_custom_op_load_method
# switch to old custom op method
use_new_custom_op_load_method(False)
def compile_so():
......
......@@ -18,6 +18,10 @@ import paddle
import numpy as np
from paddle.utils.cpp_extension import load
from utils import paddle_includes, extra_compile_args
from paddle.utils.cpp_extension.extension_utils import use_new_custom_op_load_method
# switch to old custom op method
use_new_custom_op_load_method(False)
# Compile and load custom op Just-In-Time.
relu2 = load(
......
......@@ -20,6 +20,10 @@ import paddle
import subprocess
import numpy as np
from paddle.utils.cpp_extension.extension_utils import run_cmd
from paddle.utils.cpp_extension.extension_utils import use_new_custom_op_load_method
# switch to old custom op method
use_new_custom_op_load_method(False)
class TestSetUpInstall(unittest.TestCase):
......@@ -38,7 +42,8 @@ class TestSetUpInstall(unittest.TestCase):
custom_egg_path = [
x for x in os.listdir(site_dir) if 'custom_relu2' in x
]
assert len(custom_egg_path) == 1
assert len(custom_egg_path) == 1, "Matched egg number is %d." % len(
custom_egg_path)
sys.path.append(os.path.join(site_dir, custom_egg_path[0]))
def test_api(self):
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import paddle
import numpy as np
from paddle.utils.cpp_extension import load
from utils import paddle_includes, extra_compile_args
from test_simple_custom_op_setup import relu2_dynamic, relu2_static
# Compile and load custom op Just-In-Time.
simple_relu2 = load(
name='simple_jit_relu2',
sources=['relu_op_simple.cc', 'relu_op_simple.cu'],
extra_include_paths=paddle_includes, # add for Coverage CI
extra_cflags=extra_compile_args) # add for Coverage CI
class TestJITLoad(unittest.TestCase):
def setUp(self):
self.custom_op = simple_relu2
self.dtypes = ['float32', 'float64']
self.devices = ['cpu', 'gpu']
def test_static(self):
for device in self.devices:
for dtype in self.dtypes:
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
out = relu2_static(self.custom_op, device, dtype, x)
pd_out = relu2_static(self.custom_op, device, dtype, x, False)
self.assertTrue(
np.array_equal(out, pd_out),
"custom op out: {},\n paddle api out: {}".format(out,
pd_out))
def test_dynamic(self):
for device in self.devices:
for dtype in self.dtypes:
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
out, x_grad = relu2_dynamic(self.custom_op, device, dtype, x)
pd_out, pd_x_grad = relu2_dynamic(self.custom_op, device, dtype,
x, False)
self.assertTrue(
np.array_equal(out, pd_out),
"custom op out: {},\n paddle api out: {}".format(out,
pd_out))
self.assertTrue(
np.array_equal(x_grad, pd_x_grad),
"custom op x grad: {},\n paddle api x grad: {}".format(
x_grad, pd_x_grad))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import site
import unittest
import paddle
import paddle.static as static
import subprocess
import numpy as np
from paddle.utils.cpp_extension.extension_utils import run_cmd
def relu2_dynamic(func, device, dtype, np_x, use_func=True):
paddle.set_device(device)
t = paddle.to_tensor(np_x)
t.stop_gradient = False
out = func(t) if use_func else paddle.nn.functional.relu(t)
out.stop_gradient = False
out.backward()
return out.numpy(), t.grad
def relu2_static(func, device, dtype, np_x, use_func=True):
paddle.enable_static()
paddle.set_device(device)
with static.scope_guard(static.Scope()):
with static.program_guard(static.Program()):
x = static.data(name='X', shape=[None, 8], dtype=dtype)
x.stop_gradient = False
out = func(x) if use_func else paddle.nn.functional.relu(x)
static.append_backward(out)
exe = static.Executor()
exe.run(static.default_startup_program())
# in static mode, x data has been covered by out
out_v = exe.run(static.default_main_program(),
feed={'X': np_x},
fetch_list=[out.name])
return out_v
def relu2_static_pe(func, device, dtype, np_x, use_func=True):
paddle.enable_static()
paddle.set_device(device)
places = static.cpu_places() if device is 'cpu' else static.cuda_places()
with static.scope_guard(static.Scope()):
with static.program_guard(static.Program()):
x = static.data(name='X', shape=[None, 8], dtype=dtype)
x.stop_gradient = False
out = func(x) if use_func else paddle.nn.functional.relu(x)
static.append_backward(out)
exe = static.Executor()
exe.run(static.default_startup_program())
# in static mode, x data has been covered by out
compiled_prog = static.CompiledProgram(static.default_main_program(
)).with_data_parallel(
loss_name=out.name, places=places)
out_v = exe.run(compiled_prog,
feed={'X': np_x},
fetch_list=[out.name])
return out_v
class TestNewCustomOpSetUpInstall(unittest.TestCase):
def setUp(self):
cur_dir = os.path.dirname(os.path.abspath(__file__))
# compile, install the custom op egg into site-packages under background
cmd = 'cd {} && python setup_install_simple.py install'.format(cur_dir)
run_cmd(cmd)
# NOTE(Aurelius84): Normally, it's no need to add following codes for users.
# But we simulate to pip install in current process, so interpreter don't snap
# sys.path has been updated. So we update it manually.
# See: https://stackoverflow.com/questions/56974185/import-runtime-installed-module-using-pip-in-python-3
site_dir = site.getsitepackages()[0]
custom_egg_path = [
x for x in os.listdir(site_dir) if 'simple_setup_relu2' in x
]
assert len(custom_egg_path) == 1, "Matched egg number is %d." % len(
custom_egg_path)
sys.path.append(os.path.join(site_dir, custom_egg_path[0]))
# usage: import the package directly
import simple_setup_relu2
self.custom_op = simple_setup_relu2.relu2
self.dtypes = ['float32', 'float64']
self.devices = ['cpu', 'gpu']
def test_static(self):
for device in self.devices:
for dtype in self.dtypes:
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
out = relu2_static(self.custom_op, device, dtype, x)
pd_out = relu2_static(self.custom_op, device, dtype, x, False)
self.assertTrue(
np.array_equal(out, pd_out),
"custom op out: {},\n paddle api out: {}".format(out,
pd_out))
def test_static_pe(self):
for device in self.devices:
for dtype in self.dtypes:
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
out = relu2_static_pe(self.custom_op, device, dtype, x)
pd_out = relu2_static_pe(self.custom_op, device, dtype, x,
False)
self.assertTrue(
np.array_equal(out, pd_out),
"custom op out: {},\n paddle api out: {}".format(out,
pd_out))
def test_dynamic(self):
for device in self.devices:
for dtype in self.dtypes:
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
out, x_grad = relu2_dynamic(self.custom_op, device, dtype, x)
pd_out, pd_x_grad = relu2_dynamic(self.custom_op, device, dtype,
x, False)
self.assertTrue(
np.array_equal(out, pd_out),
"custom op out: {},\n paddle api out: {}".format(out,
pd_out))
self.assertTrue(
np.array_equal(x_grad, pd_x_grad),
"custom op x grad: {},\n paddle api x grad: {}".format(
x_grad, pd_x_grad))
if __name__ == '__main__':
unittest.main()
......@@ -19,6 +19,7 @@ from .cpp_extension import load, setup
from .extension_utils import parse_op_info
from .extension_utils import get_build_directory
from .extension_utils import load_op_meta_info_and_register_op
from . import cpp_extension
from . import extension_utils
......
......@@ -25,6 +25,7 @@ from setuptools.command.build_ext import build_ext
from .extension_utils import find_cuda_home, normalize_extension_kwargs, add_compile_flag, bootstrap_context
from .extension_utils import is_cuda_file, prepare_unix_cflags, add_std_without_repeat, get_build_directory
from .extension_utils import _import_module_from_library, CustomOpInfo, _write_setup_file, _jit_compile, parse_op_name_from
from .extension_utils import use_new_custom_op_load_method
IS_WINDOWS = os.name == 'nt'
CUDA_HOME = find_cuda_home()
......@@ -132,6 +133,9 @@ class BuildExtension(build_ext, object):
super(BuildExtension, self).__init__(*args, **kwargs)
self.no_python_abi_suffix = kwargs.get("no_python_abi_suffix", True)
self.output_dir = kwargs.get("output_dir", None)
# for compatible two custom op define method
use_new_custom_op_load_method(
kwargs.get("use_new_method", use_new_custom_op_load_method()))
def initialize_options(self):
super(BuildExtension, self).initialize_options()
......
......@@ -29,6 +29,7 @@ from setuptools.command import bdist_egg
from .. import load_op_library
from ...fluid import core
from ...fluid.framework import OpProtoHolder
from ...sysconfig import get_include, get_lib
OS_NAME = platform.system()
......@@ -38,6 +39,20 @@ NVCC_COMPILE_FLAGS = [
'-Xcompiler', '-fPIC', '-w', '--expt-relaxed-constexpr', '-O3', '-DNVCC'
]
USING_NEW_CUSTOM_OP_LOAD_METHOD = True
# NOTE(chenweihang): In order to be compatible with
# the two custom op define method, after removing
# old method, we can remove them together
def use_new_custom_op_load_method(*args):
global USING_NEW_CUSTOM_OP_LOAD_METHOD
if len(args) == 0:
return USING_NEW_CUSTOM_OP_LOAD_METHOD
else:
assert len(args) == 1 and isinstance(args[0], bool)
USING_NEW_CUSTOM_OP_LOAD_METHOD = args[0]
@contextmanager
def bootstrap_context():
......@@ -51,6 +66,15 @@ def bootstrap_context():
bdist_egg.write_stub = origin_write_stub
def load_op_meta_info_and_register_op(lib_filename):
if USING_NEW_CUSTOM_OP_LOAD_METHOD:
core.load_op_meta_info_and_register_op(lib_filename)
else:
print("old branch")
core.load_op_library(lib_filename)
return OpProtoHolder.instance().update_op_proto()
def custom_write_stub(resource, pyfile):
"""
Customized write_stub function to allow us to inject generated python
......@@ -77,7 +101,7 @@ def custom_write_stub(resource, pyfile):
assert os.path.exists(so_path)
# load custom op shared library with abs path
new_custom_op = paddle.utils.load_op_library(so_path)
new_custom_op = paddle.utils.cpp_extension.load_op_meta_info_and_register_op(so_path)
assert len(new_custom_op) == 1
m = inject_ext_module(__name__, new_custom_op[0])
......@@ -90,8 +114,10 @@ def custom_write_stub(resource, pyfile):
_, op_info = CustomOpInfo.instance().last()
so_path = op_info.build_directory
new_custom_op = load_op_library(so_path)
assert len(new_custom_op) == 1
new_custom_op = load_op_meta_info_and_register_op(so_path)
assert len(new_custom_op
) == 1, "The number of loaded costom operators is %d" % len(
new_custom_op)
# NOTE: To avoid importing .so file instead of python file because they have same name,
# we rename .so shared library to another name, see EasyInstallCommand.
......@@ -338,7 +364,7 @@ def parse_op_info(op_name):
from paddle.fluid.framework import OpProtoHolder
if op_name not in OpProtoHolder.instance().op_proto_map:
raise ValueError(
"Please load {} shared library file firstly by `paddle.utils.load_op_library(...)`".
"Please load {} shared library file firstly by `paddle.utils.cpp_extension.load_op_meta_info_and_register_op(...)`".
format(op_name))
op_proto = OpProtoHolder.instance().get_op_proto(op_name)
......@@ -361,7 +387,7 @@ def _import_module_from_library(name, build_directory):
ext_path))
# load custom op_info and kernels from .so shared library
op_names = load_op_library(ext_path)
op_names = load_op_meta_info_and_register_op(ext_path)
assert len(op_names) == 1
# generate Python api in ext_path
......@@ -473,7 +499,8 @@ def _write_setup_file(name, sources, file_path, include_dirs, compile_flags,
extra_link_args={extra_link_args})],
cmdclass={{"build_ext" : BuildExtension.with_options(
output_dir=get_build_directory(),
no_python_abi_suffix=True)
no_python_abi_suffix=True,
use_new_method={use_new_method})
}})""").lstrip()
with_cuda = False
......@@ -486,7 +513,8 @@ def _write_setup_file(name, sources, file_path, include_dirs, compile_flags,
sources=list2str(sources),
include_dirs=list2str(include_dirs),
extra_compile_args=list2str(compile_flags),
extra_link_args=list2str(link_args))
extra_link_args=list2str(link_args),
use_new_method=use_new_custom_op_load_method())
with open(file_path, 'w') as f:
f.write(content)
......@@ -517,7 +545,10 @@ def parse_op_name_from(sources):
"""
def regex(content):
pattern = re.compile(r'REGISTER_OPERATOR\(([^,]+),')
if USING_NEW_CUSTOM_OP_LOAD_METHOD:
pattern = re.compile(r'BUILD_OPERATOR\(([^,]+),')
else:
pattern = re.compile(r'REGISTER_OPERATOR\(([^,]+),')
content = re.sub(r'\s|\t|\n', '', content)
op_name = pattern.findall(content)
......@@ -532,7 +563,9 @@ def parse_op_name_from(sources):
op_names |= regex(content)
# TODO(Aurelius84): Support register more customs op at once
assert len(op_names) == 1
assert len(
op_names) == 1, "The number of registered costom operators is %d" % len(
op_names)
return list(op_names)[0]
......
......@@ -380,6 +380,8 @@ def find_files(pattern, root):
yield os.path.join(dirpath, filename)
headers = (
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle')) +
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/fluid/extension')) +
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/fluid/framework')) +
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/fluid/imperative')) +
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/fluid/memory')) +
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册