From 5e8e7fb6e6cf8d418e35f6af78fb969e012ee57c Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Thu, 30 Aug 2018 13:56:21 +0800 Subject: [PATCH] change data type --- paddle/fluid/framework/CMakeLists.txt | 2 +- paddle/fluid/framework/data_type.h | 22 ++++++++++++------- paddle/fluid/framework/data_type_transform.cc | 2 +- paddle/fluid/framework/tensor_util.cc | 8 ++++--- paddle/fluid/framework/tensor_util.h | 4 ++-- paddle/fluid/operators/cast_op.h | 6 +++++ 6 files changed, 29 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index cdb3a168b12..675018be087 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -9,7 +9,7 @@ function(windows_symbolic TARGET) if (NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${src}.cc OR NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${src}.cu) message(FATAL " ${src}.cc and ${src}.cu must exsits, and ${src}.cu must be symbolic file.") endif() - add_custom_command(OUTPUT .${src}.cu PRE_BUILD + add_custom_command(OUTPUT .${src}.cu COMMAND ${CMAKE_COMMAND} -E remove ${CMAKE_CURRENT_SOURCE_DIR}/.${src}.cu COMMAND ${CMAKE_COMMAND} -E copy "${CMAKE_CURRENT_SOURCE_DIR}/${src}.cc" "${CMAKE_CURRENT_SOURCE_DIR}/.${src}.cu" COMMENT "create hidden file of ${src}.cu") diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h index 0cedc9b8361..84c2e7f2272 100644 --- a/paddle/fluid/framework/data_type.h +++ b/paddle/fluid/framework/data_type.h @@ -64,33 +64,39 @@ template inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { switch (type) { case proto::VarType::FP16: - typename visitor.operator()(); + visitor.template apply(); break; case proto::VarType::FP32: - visitor.operator()(); + visitor.template apply(); break; case proto::VarType::FP64: - visitor.operator()(); + visitor.template apply(); break; case proto::VarType::INT32: - visitor.operator()(); + visitor.template apply(); break; case proto::VarType::INT64: - visitor.operator()(); + visitor.template apply(); break; case proto::VarType::BOOL: - visitor.operator()(); + visitor.template apply(); break; case proto::VarType::UINT8: - visitor.operator()(); + visitor.template apply(); break; case proto::VarType::INT16: - visitor.operator()(); + visitor.template apply(); break; default: PADDLE_THROW("Not supported %d", type); } } + +template +void* AnyCast(const InT* t) { + return static_cast(const_cast(t)); +} + #endif // _WIN32 extern std::string DataTypeToString(const proto::VarType::Type type); diff --git a/paddle/fluid/framework/data_type_transform.cc b/paddle/fluid/framework/data_type_transform.cc index 5a57ec20585..8213c82ec1f 100644 --- a/paddle/fluid/framework/data_type_transform.cc +++ b/paddle/fluid/framework/data_type_transform.cc @@ -37,7 +37,7 @@ struct CastDataType { const platform::DeviceContext* ctx_; template - void operator()() { + void apply()() { auto* in_begin = in_.data(); auto* in_end = in_begin + in_.numel(); auto* out_begin = out_->mutable_data(in_.place()); diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index ab693004cfb..5d1e72505d8 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -137,6 +137,7 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place, #endif } +/* template struct AnyDTypeVisitor { Predicate predicate_; @@ -149,7 +150,7 @@ struct AnyDTypeVisitor { : predicate_(predicate), tensor_(tensor), ctx_(ctx), out_(out) {} template - void operator()() const { + void apply()() const { auto t = EigenVector::Flatten(tensor_); auto o = EigenScalar::From(*out_); // return any of predicate_(t) is true. @@ -173,7 +174,7 @@ struct AnyVisitor : public boost::static_visitor { : tensor_(tensor), predicate_(std::move(predicate)) {} template - bool operator()(const Place& place) const { + bool apply()(const Place& place) const { framework::Tensor out; out.Resize({1}); out.mutable_data(place); @@ -240,6 +241,7 @@ bool TensorContainsInf(const framework::Tensor& tensor) { ContainsInfPredicate predicate; return Any(tensor, predicate); } +*/ void TensorToStream(std::ostream& os, const Tensor& tensor, const platform::DeviceContext& dev_ctx) { @@ -302,7 +304,7 @@ struct DeserializedDataFunctor { : buf_(buf), tensor_(tensor), place_(place) {} template - void operator()() { + void apply() { *buf_ = tensor_->mutable_data(place_); } diff --git a/paddle/fluid/framework/tensor_util.h b/paddle/fluid/framework/tensor_util.h index 4457382ade3..addf71f4dc8 100644 --- a/paddle/fluid/framework/tensor_util.h +++ b/paddle/fluid/framework/tensor_util.h @@ -57,8 +57,8 @@ void TensorToVector(const Tensor& src, const platform::DeviceContext& ctx, template void TesnorToVector(const Tensor& src, std::vector* dst); -bool TensorContainsNAN(const framework::Tensor& tensor); -bool TensorContainsInf(const framework::Tensor& tensor); +// bool TensorContainsNAN(const framework::Tensor& tensor); +// bool TensorContainsInf(const framework::Tensor& tensor); void TensorToStream(std::ostream& os, const Tensor& tensor, const platform::DeviceContext& dev_ctx); diff --git a/paddle/fluid/operators/cast_op.h b/paddle/fluid/operators/cast_op.h index 6220e57f594..abc209d58d0 100644 --- a/paddle/fluid/operators/cast_op.h +++ b/paddle/fluid/operators/cast_op.h @@ -54,11 +54,17 @@ class CastOpKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { auto* in = context.Input("X"); auto* out = context.Output("Out"); +#if !defined(_MSC_VER) framework::VisitDataType( static_cast( context.Attr("out_dtype")), CastOpFunctor( in, out, context.template device_context())); +#else + auto type = static_cast( + context.Attr("out_dtype")); + trans +#endif // msvc } }; -- GitLab