提交 5e8e7fb6 编写于 作者: D dzhwinter

change data type

上级 f5329d65
......@@ -9,7 +9,7 @@ function(windows_symbolic TARGET)
if (NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${src}.cc OR NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${src}.cu)
message(FATAL " ${src}.cc and ${src}.cu must exsits, and ${src}.cu must be symbolic file.")
endif()
add_custom_command(OUTPUT .${src}.cu PRE_BUILD
add_custom_command(OUTPUT .${src}.cu
COMMAND ${CMAKE_COMMAND} -E remove ${CMAKE_CURRENT_SOURCE_DIR}/.${src}.cu
COMMAND ${CMAKE_COMMAND} -E copy "${CMAKE_CURRENT_SOURCE_DIR}/${src}.cc" "${CMAKE_CURRENT_SOURCE_DIR}/.${src}.cu"
COMMENT "create hidden file of ${src}.cu")
......
......@@ -64,33 +64,39 @@ template <typename Visitor>
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
switch (type) {
case proto::VarType::FP16:
typename visitor.operator()<platform::float16>();
visitor.template apply<platform::float16>();
break;
case proto::VarType::FP32:
visitor.operator()<float>();
visitor.template apply<float>();
break;
case proto::VarType::FP64:
visitor.operator()<double>();
visitor.template apply<double>();
break;
case proto::VarType::INT32:
visitor.operator()<int>();
visitor.template apply<int>();
break;
case proto::VarType::INT64:
visitor.operator()<int64_t>();
visitor.template apply<int64_t>();
break;
case proto::VarType::BOOL:
visitor.operator()<bool>();
visitor.template apply<bool>();
break;
case proto::VarType::UINT8:
visitor.operator()<uint8_t>();
visitor.template apply<uint8_t>();
break;
case proto::VarType::INT16:
visitor.operator()<int16_t>();
visitor.template apply<int16_t>();
break;
default:
PADDLE_THROW("Not supported %d", type);
}
}
template <typename InT>
void* AnyCast(const InT* t) {
return static_cast<void*>(const_cast<InT*>(t));
}
#endif // _WIN32
extern std::string DataTypeToString(const proto::VarType::Type type);
......
......@@ -37,7 +37,7 @@ struct CastDataType {
const platform::DeviceContext* ctx_;
template <typename OutType>
void operator()() {
void apply()() {
auto* in_begin = in_.data<InType>();
auto* in_end = in_begin + in_.numel();
auto* out_begin = out_->mutable_data<OutType>(in_.place());
......
......@@ -137,6 +137,7 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
#endif
}
/*
template <typename Predicate, typename DevCtx>
struct AnyDTypeVisitor {
Predicate predicate_;
......@@ -149,7 +150,7 @@ struct AnyDTypeVisitor {
: predicate_(predicate), tensor_(tensor), ctx_(ctx), out_(out) {}
template <typename T>
void operator()() const {
void apply()() const {
auto t = EigenVector<T>::Flatten(tensor_);
auto o = EigenScalar<bool>::From(*out_);
// return any of predicate_(t) is true.
......@@ -173,7 +174,7 @@ struct AnyVisitor : public boost::static_visitor<bool> {
: tensor_(tensor), predicate_(std::move(predicate)) {}
template <typename Place>
bool operator()(const Place& place) const {
bool apply()(const Place& place) const {
framework::Tensor out;
out.Resize({1});
out.mutable_data<bool>(place);
......@@ -240,6 +241,7 @@ bool TensorContainsInf(const framework::Tensor& tensor) {
ContainsInfPredicate predicate;
return Any(tensor, predicate);
}
*/
void TensorToStream(std::ostream& os, const Tensor& tensor,
const platform::DeviceContext& dev_ctx) {
......@@ -302,7 +304,7 @@ struct DeserializedDataFunctor {
: buf_(buf), tensor_(tensor), place_(place) {}
template <typename T>
void operator()() {
void apply() {
*buf_ = tensor_->mutable_data<T>(place_);
}
......
......@@ -57,8 +57,8 @@ void TensorToVector(const Tensor& src, const platform::DeviceContext& ctx,
template <typename T>
void TesnorToVector(const Tensor& src, std::vector<T>* dst);
bool TensorContainsNAN(const framework::Tensor& tensor);
bool TensorContainsInf(const framework::Tensor& tensor);
// bool TensorContainsNAN(const framework::Tensor& tensor);
// bool TensorContainsInf(const framework::Tensor& tensor);
void TensorToStream(std::ostream& os, const Tensor& tensor,
const platform::DeviceContext& dev_ctx);
......
......@@ -54,11 +54,17 @@ class CastOpKernel : public framework::OpKernel<InT> {
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
#if !defined(_MSC_VER)
framework::VisitDataType(
static_cast<framework::proto::VarType::Type>(
context.Attr<int>("out_dtype")),
CastOpFunctor<DeviceContext, InT>(
in, out, context.template device_context<DeviceContext>()));
#else
auto type = static_cast<framework::proto::VarType::Type>(
context.Attr<int>("out_dtype"));
trans
#endif // msvc
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册