提交 5e8e7fb6 · 作者：dzhwinter

change data type

上级 f5329d65
...@@ -9,7 +9,7 @@ function(windows_symbolic TARGET) ...@@ -9,7 +9,7 @@ function(windows_symbolic TARGET)
if (NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${src}.cc OR NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${src}.cu) if (NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${src}.cc OR NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${src}.cu)
message(FATAL " ${src}.cc and ${src}.cu must exsits, and ${src}.cu must be symbolic file.") message(FATAL " ${src}.cc and ${src}.cu must exsits, and ${src}.cu must be symbolic file.")
endif() endif()
add_custom_command(OUTPUT .${src}.cu PRE_BUILD add_custom_command(OUTPUT .${src}.cu
COMMAND ${CMAKE_COMMAND} -E remove ${CMAKE_CURRENT_SOURCE_DIR}/.${src}.cu COMMAND ${CMAKE_COMMAND} -E remove ${CMAKE_CURRENT_SOURCE_DIR}/.${src}.cu
COMMAND ${CMAKE_COMMAND} -E copy "${CMAKE_CURRENT_SOURCE_DIR}/${src}.cc" "${CMAKE_CURRENT_SOURCE_DIR}/.${src}.cu" COMMAND ${CMAKE_COMMAND} -E copy "${CMAKE_CURRENT_SOURCE_DIR}/${src}.cc" "${CMAKE_CURRENT_SOURCE_DIR}/.${src}.cu"
COMMENT "create hidden file of ${src}.cu") COMMENT "create hidden file of ${src}.cu")
......
...@@ -64,33 +64,39 @@ template <typename Visitor> ...@@ -64,33 +64,39 @@ template <typename Visitor>
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
switch (type) { switch (type) {
case proto::VarType::FP16: case proto::VarType::FP16:
typename visitor.operator()<platform::float16>(); visitor.template apply<platform::float16>();
break; break;
case proto::VarType::FP32: case proto::VarType::FP32:
visitor.operator()<float>(); visitor.template apply<float>();
break; break;
case proto::VarType::FP64: case proto::VarType::FP64:
visitor.operator()<double>(); visitor.template apply<double>();
break; break;
case proto::VarType::INT32: case proto::VarType::INT32:
visitor.operator()<int>(); visitor.template apply<int>();
break; break;
case proto::VarType::INT64: case proto::VarType::INT64:
visitor.operator()<int64_t>(); visitor.template apply<int64_t>();
break; break;
case proto::VarType::BOOL: case proto::VarType::BOOL:
visitor.operator()<bool>(); visitor.template apply<bool>();
break; break;
case proto::VarType::UINT8: case proto::VarType::UINT8:
visitor.operator()<uint8_t>(); visitor.template apply<uint8_t>();
break; break;
case proto::VarType::INT16: case proto::VarType::INT16:
visitor.operator()<int16_t>(); visitor.template apply<int16_t>();
break; break;
default: default:
PADDLE_THROW("Not supported %d", type); PADDLE_THROW("Not supported %d", type);
} }
} }
// Type-erases a (possibly const) typed pointer into a mutable void*.
// NOTE: this deliberately strips const from the pointee; callers must not
// write through the result if the original object is truly const.
template <typename InT>
void* AnyCast(const InT* t) {
  return const_cast<void*>(static_cast<const void*>(t));
}
#endif // _WIN32 #endif // _WIN32
extern std::string DataTypeToString(const proto::VarType::Type type); extern std::string DataTypeToString(const proto::VarType::Type type);
......
...@@ -37,7 +37,7 @@ struct CastDataType { ...@@ -37,7 +37,7 @@ struct CastDataType {
const platform::DeviceContext* ctx_; const platform::DeviceContext* ctx_;
template <typename OutType> template <typename OutType>
void operator()() { void apply()() {
auto* in_begin = in_.data<InType>(); auto* in_begin = in_.data<InType>();
auto* in_end = in_begin + in_.numel(); auto* in_end = in_begin + in_.numel();
auto* out_begin = out_->mutable_data<OutType>(in_.place()); auto* out_begin = out_->mutable_data<OutType>(in_.place());
......
...@@ -137,6 +137,7 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place, ...@@ -137,6 +137,7 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
#endif #endif
} }
/*
template <typename Predicate, typename DevCtx> template <typename Predicate, typename DevCtx>
struct AnyDTypeVisitor { struct AnyDTypeVisitor {
Predicate predicate_; Predicate predicate_;
...@@ -149,7 +150,7 @@ struct AnyDTypeVisitor { ...@@ -149,7 +150,7 @@ struct AnyDTypeVisitor {
: predicate_(predicate), tensor_(tensor), ctx_(ctx), out_(out) {} : predicate_(predicate), tensor_(tensor), ctx_(ctx), out_(out) {}
template <typename T> template <typename T>
void operator()() const { void apply()() const {
auto t = EigenVector<T>::Flatten(tensor_); auto t = EigenVector<T>::Flatten(tensor_);
auto o = EigenScalar<bool>::From(*out_); auto o = EigenScalar<bool>::From(*out_);
// return any of predicate_(t) is true. // return any of predicate_(t) is true.
...@@ -173,7 +174,7 @@ struct AnyVisitor : public boost::static_visitor<bool> { ...@@ -173,7 +174,7 @@ struct AnyVisitor : public boost::static_visitor<bool> {
: tensor_(tensor), predicate_(std::move(predicate)) {} : tensor_(tensor), predicate_(std::move(predicate)) {}
template <typename Place> template <typename Place>
bool operator()(const Place& place) const { bool apply()(const Place& place) const {
framework::Tensor out; framework::Tensor out;
out.Resize({1}); out.Resize({1});
out.mutable_data<bool>(place); out.mutable_data<bool>(place);
...@@ -240,6 +241,7 @@ bool TensorContainsInf(const framework::Tensor& tensor) { ...@@ -240,6 +241,7 @@ bool TensorContainsInf(const framework::Tensor& tensor) {
ContainsInfPredicate predicate; ContainsInfPredicate predicate;
return Any(tensor, predicate); return Any(tensor, predicate);
} }
*/
void TensorToStream(std::ostream& os, const Tensor& tensor, void TensorToStream(std::ostream& os, const Tensor& tensor,
const platform::DeviceContext& dev_ctx) { const platform::DeviceContext& dev_ctx) {
...@@ -302,7 +304,7 @@ struct DeserializedDataFunctor { ...@@ -302,7 +304,7 @@ struct DeserializedDataFunctor {
: buf_(buf), tensor_(tensor), place_(place) {} : buf_(buf), tensor_(tensor), place_(place) {}
template <typename T> template <typename T>
void operator()() { void apply() {
*buf_ = tensor_->mutable_data<T>(place_); *buf_ = tensor_->mutable_data<T>(place_);
} }
......
...@@ -57,8 +57,8 @@ void TensorToVector(const Tensor& src, const platform::DeviceContext& ctx, ...@@ -57,8 +57,8 @@ void TensorToVector(const Tensor& src, const platform::DeviceContext& ctx,
template <typename T> template <typename T>
void TesnorToVector(const Tensor& src, std::vector<T>* dst); void TesnorToVector(const Tensor& src, std::vector<T>* dst);
bool TensorContainsNAN(const framework::Tensor& tensor); // bool TensorContainsNAN(const framework::Tensor& tensor);
bool TensorContainsInf(const framework::Tensor& tensor); // bool TensorContainsInf(const framework::Tensor& tensor);
void TensorToStream(std::ostream& os, const Tensor& tensor, void TensorToStream(std::ostream& os, const Tensor& tensor,
const platform::DeviceContext& dev_ctx); const platform::DeviceContext& dev_ctx);
......
...@@ -54,11 +54,17 @@ class CastOpKernel : public framework::OpKernel<InT> { ...@@ -54,11 +54,17 @@ class CastOpKernel : public framework::OpKernel<InT> {
// Casts input tensor "X" to the element type given by the integer attribute
// "out_dtype" and writes the converted data to output tensor "Out".
void Compute(const framework::ExecutionContext& context) const override {
  auto* in = context.Input<framework::Tensor>("X");
  auto* out = context.Output<framework::Tensor>("Out");
  // NOTE(review): the MSVC-only `#else` branch contained the bare token
  // `trans`, which does not compile. Since VisitDataType dispatches through
  // `visitor.template apply<T>()` on every platform, a single code path
  // suffices; the preprocessor split is removed.
  framework::VisitDataType(
      static_cast<framework::proto::VarType::Type>(
          context.Attr<int>("out_dtype")),
      CastOpFunctor<DeviceContext, InT>(
          in, out, context.template device_context<DeviceContext>()));
}
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册