未验证 提交 628451af 编写于 作者: J Jiabin Yang 提交者: GitHub

hide useless headers and add complex support (#31074)

上级 463eae03
...@@ -13,13 +13,14 @@ See the License for the specific language governing permissions and ...@@ -13,13 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle { namespace paddle {
struct complex128;
struct complex64;
struct float16;
struct bfloat16;
enum DataType { enum DataType {
FLOAT32, FLOAT32,
FLOAT64, FLOAT64,
......
...@@ -293,6 +293,11 @@ class OpMetaInfoBuilder { ...@@ -293,6 +293,11 @@ class OpMetaInfoBuilder {
// Call after PD_BUILD_OP(...) // Call after PD_BUILD_OP(...)
void RegisterAllCustomOperator(); void RegisterAllCustomOperator();
// Using this api to load compiled custom operator's dynamic library and
// register Custom
// Operator into it
void LoadCustomOperatorLib(const std::string& dso_name);
/////////////////////// Op register Macro ///////////////////////// /////////////////////// Op register Macro /////////////////////////
#define PD_BUILD_OP(op_name) \ #define PD_BUILD_OP(op_name) \
......
...@@ -25,13 +25,13 @@ class CustomTensorUtils; ...@@ -25,13 +25,13 @@ class CustomTensorUtils;
} // namespace framework } // namespace framework
class Tensor { class Tensor {
public: public:
/// \brief Construct a Tensor on None Place for CustomOp. /// \brief Construct a Tensor on target Place for CustomOp.
/// Generally it's only used for user to create Tensor. /// Generally it's only used for user to create Tensor.
explicit Tensor(const PlaceType& place); explicit Tensor(const PlaceType& place);
/// \brief Reset the shape of the tensor. /// \brief Reset the shape of the tensor.
/// Generally it's only used for the input tensor. /// Generally it's only used for the input tensor.
/// Reshape must be called before calling /// Reshape must be called before calling
/// mutable_data() or copy_from_cpu() /// mutable_data() or copy_to(const PlaceType& place)
/// \param shape The shape to set. /// \param shape The shape to set.
void reshape(const std::vector<int>& shape); void reshape(const std::vector<int>& shape);
...@@ -59,11 +59,11 @@ class Tensor { ...@@ -59,11 +59,11 @@ class Tensor {
/// \brief Copy the host memory to tensor data. /// \brief Copy the host memory to tensor data.
/// It's usually used to set the input tensor data. /// It's usually used to set the input tensor data.
/// \param PlaceType of target place, from which /// \param PlaceType of target place, of which
/// the tensor will copy. /// the tensor will copy to.
template <typename T> template <typename T>
Tensor copy_to(const PlaceType& place); Tensor copy_to(const PlaceType& place) const;
/// \brief Return the shape of the Tensor. /// \brief Return the shape of the Tensor.
std::vector<int> shape() const; std::vector<int> shape() const;
...@@ -84,7 +84,7 @@ class Tensor { ...@@ -84,7 +84,7 @@ class Tensor {
const PlaceType& place() const; const PlaceType& place() const;
/// \brief Cast datatype from one to another /// \brief Cast datatype from one to another
Tensor cast(const DataType& target_type); Tensor cast(const DataType& target_type) const;
private: private:
friend class framework::CustomTensorUtils; friend class framework::CustomTensorUtils;
......
...@@ -109,6 +109,9 @@ void RegisterAllCustomOperator() { ...@@ -109,6 +109,9 @@ void RegisterAllCustomOperator() {
framework::RegisterOperatorWithMetaInfoMap(op_meta_info_map); framework::RegisterOperatorWithMetaInfoMap(op_meta_info_map);
} }
void LoadCustomOperatorLib(const std::string& dso_name) {
paddle::framework::LoadOpMetaInfoAndRegisterOp(dso_name);
}
} // namespace paddle } // namespace paddle
extern "C" { extern "C" {
......
...@@ -17,7 +17,11 @@ limitations under the License. */ ...@@ -17,7 +17,11 @@ limitations under the License. */
#include "paddle/fluid/framework/custom_tensor_utils.h" #include "paddle/fluid/framework/custom_tensor_utils.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/transform.h" #include "paddle/fluid/platform/transform.h"
namespace paddle { namespace paddle {
...@@ -174,7 +178,7 @@ DataType Tensor::type() const { ...@@ -174,7 +178,7 @@ DataType Tensor::type() const {
} }
template <typename T> template <typename T>
Tensor Tensor::copy_to(const PlaceType &target_place) { Tensor Tensor::copy_to(const PlaceType &target_place) const {
GET_CASTED_TENSOR; GET_CASTED_TENSOR;
PADDLE_ENFORCE_GE(tensor->numel(), 0, PADDLE_ENFORCE_GE(tensor->numel(), 0,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
...@@ -208,21 +212,21 @@ Tensor Tensor::copy_to(const PlaceType &target_place) { ...@@ -208,21 +212,21 @@ Tensor Tensor::copy_to(const PlaceType &target_place) {
} }
template Tensor Tensor::copy_to<paddle::platform::float16>( template Tensor Tensor::copy_to<paddle::platform::float16>(
const PlaceType &target_place); const PlaceType &target_place) const;
template Tensor Tensor::copy_to<paddle::platform::bfloat16>( template Tensor Tensor::copy_to<paddle::platform::bfloat16>(
const PlaceType &target_place); const PlaceType &target_place) const;
template Tensor Tensor::copy_to<paddle::platform::complex64>( template Tensor Tensor::copy_to<paddle::platform::complex64>(
const PlaceType &target_place); const PlaceType &target_place) const;
template Tensor Tensor::copy_to<paddle::platform::complex128>( template Tensor Tensor::copy_to<paddle::platform::complex128>(
const PlaceType &target_place); const PlaceType &target_place) const;
template Tensor Tensor::copy_to<float>(const PlaceType &target_place); template Tensor Tensor::copy_to<float>(const PlaceType &target_place) const;
template Tensor Tensor::copy_to<double>(const PlaceType &target_place); template Tensor Tensor::copy_to<double>(const PlaceType &target_place) const;
template Tensor Tensor::copy_to<int64_t>(const PlaceType &target_place); template Tensor Tensor::copy_to<int64_t>(const PlaceType &target_place) const;
template Tensor Tensor::copy_to<int32_t>(const PlaceType &target_place); template Tensor Tensor::copy_to<int32_t>(const PlaceType &target_place) const;
template Tensor Tensor::copy_to<uint8_t>(const PlaceType &target_place); template Tensor Tensor::copy_to<uint8_t>(const PlaceType &target_place) const;
template Tensor Tensor::copy_to<int8_t>(const PlaceType &target_place); template Tensor Tensor::copy_to<int8_t>(const PlaceType &target_place) const;
template Tensor Tensor::copy_to<int16_t>(const PlaceType &target_place); template Tensor Tensor::copy_to<int16_t>(const PlaceType &target_place) const;
template Tensor Tensor::copy_to<bool>(const PlaceType &target_place); template Tensor Tensor::copy_to<bool>(const PlaceType &target_place) const;
template float *Tensor::data<float>() const; template float *Tensor::data<float>() const;
template double *Tensor::data<double>() const; template double *Tensor::data<double>() const;
...@@ -295,7 +299,7 @@ const PlaceType &Tensor::place() const { ...@@ -295,7 +299,7 @@ const PlaceType &Tensor::place() const {
return place_; return place_;
} }
Tensor Tensor::cast(const DataType &target_type) { Tensor Tensor::cast(const DataType &target_type) const {
GET_CASTED_TENSOR; GET_CASTED_TENSOR;
Tensor rlt = Tensor(place()); Tensor rlt = Tensor(place());
rlt.reshape(this->shape()); rlt.reshape(this->shape());
...@@ -342,7 +346,14 @@ Tensor Tensor::cast(const DataType &target_type) { ...@@ -342,7 +346,14 @@ Tensor Tensor::cast(const DataType &target_type) {
framework::VisitDataType( framework::VisitDataType(
dst_type, CastDataType<uint8_t>(*tensor, rlt_tensor_, ctx)); dst_type, CastDataType<uint8_t>(*tensor, rlt_tensor_, ctx));
break; break;
// TODO(JiabinYang): Support Complex later case framework::proto::VarType::COMPLEX64:
framework::VisitDataType(dst_type, CastDataType<platform::complex64>(
*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::COMPLEX128:
framework::VisitDataType(dst_type, CastDataType<platform::complex128>(
*tensor, rlt_tensor_, ctx));
break;
default: default:
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Data type (%s) is not supported when casting data type.", "Data type (%s) is not supported when casting data type.",
......
...@@ -25,7 +25,7 @@ paddle::Tensor InitCPUTensorForTest() { ...@@ -25,7 +25,7 @@ paddle::Tensor InitCPUTensorForTest() {
t1.reshape(tensor_shape); t1.reshape(tensor_shape);
auto* p_data_ptr = t1.mutable_data<T>(paddle::PlaceType::kCPU); auto* p_data_ptr = t1.mutable_data<T>(paddle::PlaceType::kCPU);
for (int64_t i = 0; i < t1.size(); i++) { for (int64_t i = 0; i < t1.size(); i++) {
p_data_ptr[i] = 5; p_data_ptr[i] = T(5);
} }
return t1; return t1;
} }
...@@ -36,7 +36,7 @@ void TestCopyTensor() { ...@@ -36,7 +36,7 @@ void TestCopyTensor() {
auto t1_cpu_cp = t1.template copy_to<T>(paddle::PlaceType::kCPU); auto t1_cpu_cp = t1.template copy_to<T>(paddle::PlaceType::kCPU);
CHECK((paddle::PlaceType::kCPU == t1_cpu_cp.place())); CHECK((paddle::PlaceType::kCPU == t1_cpu_cp.place()));
for (int64_t i = 0; i < t1.size(); i++) { for (int64_t i = 0; i < t1.size(); i++) {
CHECK_EQ(t1_cpu_cp.template data<T>()[i], 5); CHECK_EQ(t1_cpu_cp.template data<T>()[i], T(5));
} }
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
VLOG(2) << "Do GPU copy test"; VLOG(2) << "Do GPU copy test";
...@@ -48,7 +48,7 @@ void TestCopyTensor() { ...@@ -48,7 +48,7 @@ void TestCopyTensor() {
t1_gpu_cp.template copy_to<T>(paddle::PlaceType::kCPU); t1_gpu_cp.template copy_to<T>(paddle::PlaceType::kCPU);
CHECK((paddle::PlaceType::kCPU == t1_gpu_cp_cp_cpu.place())); CHECK((paddle::PlaceType::kCPU == t1_gpu_cp_cp_cpu.place()));
for (int64_t i = 0; i < t1.size(); i++) { for (int64_t i = 0; i < t1.size(); i++) {
CHECK_EQ(t1_gpu_cp_cp_cpu.template data<T>()[i], 5); CHECK_EQ(t1_gpu_cp_cp_cpu.template data<T>()[i], T(5));
} }
#endif #endif
} }
...@@ -99,16 +99,15 @@ void GroupTestCopy() { ...@@ -99,16 +99,15 @@ void GroupTestCopy() {
TestCopyTensor<float>(); TestCopyTensor<float>();
VLOG(2) << "Double cpu-cpu-gpu-gpu-cpu"; VLOG(2) << "Double cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<double>(); TestCopyTensor<double>();
// TODO(JiabinYang): Support these test later VLOG(2) << "Fp16 cpu-cpu-gpu-gpu-cpu";
// VLOG(2) << "Fp16 cpu-cpu-gpu-gpu-cpu"; TestCopyTensor<paddle::platform::float16>();
// TestCopyTensor<paddle::platform::float16>(); VLOG(2) << "BF16 cpu-cpu-gpu-gpu-cpu";
// VLOG(2) << "BF16 cpu-cpu-gpu-gpu-cpu"; TestCopyTensor<paddle::platform::bfloat16>();
// TestCopyTensor<paddle::platform::bfloat16>(); VLOG(2) << "complex128 cpu-cpu-gpu-gpu-cpu";
// VLOG(2) << "complex128 cpu-cpu-gpu-gpu-cpu"; TestCopyTensor<paddle::platform::complex128>();
// TestCopyTensor<paddle::platform::complex128>(); VLOG(2) << "complex64 cpu-cpu-gpu-gpu-cpu";
// VLOG(2) << "complex64 cpu-cpu-gpu-gpu-cpu"; TestCopyTensor<paddle::platform::complex64>();
// TestCopyTensor<paddle::platform::complex64>(); VLOG(2) << "int cpu-cpu-gpu-gpu-cpu";
// VLOG(2) << "int cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<int>(); TestCopyTensor<int>();
VLOG(2) << "int64 cpu-cpu-gpu-gpu-cpu"; VLOG(2) << "int64 cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<int64_t>(); TestCopyTensor<int64_t>();
...@@ -139,6 +138,10 @@ void GroupTestCast() { ...@@ -139,6 +138,10 @@ void GroupTestCast() {
TestCast<uint8_t>(paddle::DataType::FLOAT32); TestCast<uint8_t>(paddle::DataType::FLOAT32);
VLOG(2) << "float cast"; VLOG(2) << "float cast";
TestCast<float>(paddle::DataType::FLOAT32); TestCast<float>(paddle::DataType::FLOAT32);
VLOG(2) << "complex64 cast";
TestCast<float>(paddle::DataType::FLOAT32);
VLOG(2) << "complex128 cast";
TestCast<float>(paddle::DataType::FLOAT32);
} }
void GroupTestDtype() { void GroupTestDtype() {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册