Unverified commit ab1097cd, authored by dzhwinter, committed by GitHub

Feature/template (#13093)

* remove template operator

* "fix compile"

* "fix ci"

* "fix ci"
Parent d0c65bff
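
The rename from `operator()` to a named member `apply` exists because MSVC does not implement two-phase name lookup correctly and mishandles the `visitor.template operator()<T>()` call spelling, which previously forced a duplicated `#ifdef`-guarded copy of `VisitDataType` (removed in the second hunk below). A minimal sketch of the new calling convention; the `PrintSize` functor and `Visit` wrapper are illustrative, not part of this patch:

#include <cstdio>

// A visitor in the new style: the member template is named apply
// instead of operator().
struct PrintSize {
  template <typename T>
  void apply() const {
    std::printf("%zu\n", sizeof(T));
  }
};

template <typename Visitor>
void Visit(Visitor visitor) {
  // Inside a template, calling a member template of a dependent object
  // requires the "template" keyword. With a named member this parses on
  // every compiler:
  visitor.template apply<float>();
  // The old spelling, "visitor.template operator()<float>();", is what
  // MSVC failed to parse, hence the removed _WIN32 fork below.
}

int main() { Visit(PrintSize{}); }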
@@ -46,7 +46,7 @@ struct CastDataLayout {
   const std::vector<int> axis_;

   template <typename T>
-  void operator()() {
+  void apply() {
     auto place = ctx_->GetPlace();
     if (platform::is_cpu_place(place)) {
...
@@ -26,75 +26,40 @@ namespace framework {
 extern proto::VarType::Type ToDataType(std::type_index type);
 extern std::type_index ToTypeIndex(proto::VarType::Type type);

-#if !defined(_WIN32)
 template <typename Visitor>
 inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
   switch (type) {
     case proto::VarType::FP16:
-      visitor.template operator()<platform::float16>();
+      visitor.template apply<platform::float16>();
       break;
     case proto::VarType::FP32:
-      visitor.template operator()<float>();
+      visitor.template apply<float>();
       break;
     case proto::VarType::FP64:
-      visitor.template operator()<double>();
+      visitor.template apply<double>();
       break;
     case proto::VarType::INT32:
-      visitor.template operator()<int>();
+      visitor.template apply<int>();
       break;
     case proto::VarType::INT64:
-      visitor.template operator()<int64_t>();
+      visitor.template apply<int64_t>();
       break;
     case proto::VarType::BOOL:
-      visitor.template operator()<bool>();
+      visitor.template apply<bool>();
       break;
     case proto::VarType::UINT8:
-      visitor.template operator()<uint8_t>();
+      visitor.template apply<uint8_t>();
       break;
     case proto::VarType::INT16:
-      visitor.template operator()<int16_t>();
+      visitor.template apply<int16_t>();
       break;
     case proto::VarType::INT8:
-      visitor.template operator()<int8_t>();
+      visitor.template apply<int8_t>();
       break;
     default:
       PADDLE_THROW("Not supported %d", type);
   }
 }
-#else
-// the msvc compiler do not implement two-stage name lookup correctly.
-template <typename Visitor>
-inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
-  switch (type) {
-    case proto::VarType::FP16:
-      visitor.operator()<platform::float16>();
-      break;
-    case proto::VarType::FP32:
-      visitor.operator()<float>();
-      break;
-    case proto::VarType::FP64:
-      visitor.operator()<double>();
-      break;
-    case proto::VarType::INT32:
-      visitor.operator()<int>();
-      break;
-    case proto::VarType::INT64:
-      visitor.operator()<int64_t>();
-      break;
-    case proto::VarType::BOOL:
-      visitor.operator()<bool>();
-      break;
-    case proto::VarType::UINT8:
-      visitor.operator()<uint8_t>();
-      break;
-    case proto::VarType::INT16:
-      visitor.operator()<int16_t>();
-      break;
-    default:
-      PADDLE_THROW("Not supported %d", type);
-  }
-}
-#endif  // _WIN32

 extern std::string DataTypeToString(const proto::VarType::Type type);
 extern size_t SizeOfType(std::type_index type);
...
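
After this change a single `VisitDataType` maps the runtime dtype enum to a compile-time type and calls `visitor.template apply<T>()` on all platforms. A self-contained sketch of the same dispatch shape, with a stand-in enum instead of `proto::VarType::Type` (all names here are illustrative):

#include <cstdint>
#include <cstdio>
#include <stdexcept>

enum class DType { FP32, INT64 };  // stand-in for proto::VarType::Type

struct SizeOfVisitor {
  template <typename T>
  void apply() const { std::printf("%zu\n", sizeof(T)); }
};

template <typename Visitor>
inline void VisitDType(DType type, Visitor visitor) {
  switch (type) {  // one case per supported dtype, as in the hunk above
    case DType::FP32:
      visitor.template apply<float>();
      break;
    case DType::INT64:
      visitor.template apply<int64_t>();
      break;
    default:
      throw std::runtime_error("Not supported");
  }
}

int main() { VisitDType(DType::INT64, SizeOfVisitor{}); }  // prints 8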
@@ -37,7 +37,7 @@ struct CastDataType {
   const platform::DeviceContext* ctx_;

   template <typename OutType>
-  void operator()() {
+  void apply() {
     auto* in_begin = in_.data<InType>();
     auto* in_end = in_begin + in_.numel();
     auto* out_begin = out_->mutable_data<OutType>(in_.place());
...
@@ -31,7 +31,7 @@ struct ReduceLoDTensor {
       : src_tensors_(src), dst_tensor_(*dst) {}

   template <typename T>
-  void operator()() const {
+  void apply() const {
     PADDLE_ENFORCE(!src_tensors_.empty());
     auto &t0 = *src_tensors_[0];
     PADDLE_ENFORCE_NE(t0.numel(), 0);
...
@@ -49,7 +49,7 @@ struct TensorCopyVisitor {
         size_(size) {}

   template <typename T>
-  void operator()() const {
+  void apply() const {
     // TODO(Yancey1989): support other place
     platform::CPUPlace cpu;
     memory::Copy(cpu, dst_->mutable_data<T>(cpu) + dst_offset_, cpu,
...
@@ -149,7 +149,7 @@ struct AnyDTypeVisitor {
       : predicate_(predicate), tensor_(tensor), ctx_(ctx), out_(out) {}

   template <typename T>
-  void operator()() const {
+  void apply() const {
     auto t = EigenVector<T>::Flatten(tensor_);
     auto o = EigenScalar<bool>::From(*out_);
     // return any of predicate_(t) is true.
@@ -302,7 +302,7 @@ struct DeserializedDataFunctor {
       : buf_(buf), tensor_(tensor), place_(place) {}

   template <typename T>
-  void operator()() {
+  void apply() {
     *buf_ = tensor_->mutable_data<T>(place_);
   }
...
@@ -74,7 +74,7 @@ struct BeamSearchDecodeFunctor {
   }

   template <typename T>
-  void operator()() const;
+  void apply() const;

   bool tensor_on_gpu_;
   size_t beam_size_;
@@ -88,7 +88,7 @@ struct BeamSearchDecodeFunctor {
 };

 template <typename T>
-void BeamSearchDecodeFunctor::operator()() const {
+void BeamSearchDecodeFunctor::apply() const {
   BeamSearchDecoder<T> beam_search_decoder(beam_size_, end_id_);
   // Check if the tensor is on GPU. If so, use the CPU copy instead
   if (tensor_on_gpu_) {
@@ -101,7 +101,7 @@ void BeamSearchDecodeFunctor::operator()() const {
 }

 template <>
-void BeamSearchDecodeFunctor::operator()<bool>() const {
+void BeamSearchDecodeFunctor::apply<bool>() const {
   PADDLE_THROW("beam search decode op does not support bool!");
 }
...
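
The full specialization above shows that the rename keeps the usual member-template machinery working: `BeamSearchDecodeFunctor` still rejects `bool` by specializing `apply<bool>`. A minimal sketch of that pattern; the `Decoder` type is illustrative, not from this patch:

#include <cstdio>
#include <stdexcept>

struct Decoder {
  template <typename T>
  void apply() const;
};

// Primary definition covers the supported element types.
template <typename T>
void Decoder::apply() const {
  std::printf("decoding %zu-byte elements\n", sizeof(T));
}

// Full specialization rejects bool, mirroring apply<bool> in the
// diff above.
template <>
void Decoder::apply<bool>() const {
  throw std::runtime_error("decode does not support bool");
}

int main() {
  Decoder d;
  d.apply<float>();    // ok
  // d.apply<bool>();  // would throw
}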
@@ -37,7 +37,7 @@ struct CastOpFunctor {
       : in_(in), out_(out), ctx_(ctx) {}

   template <typename OutT>
-  void operator()() const {
+  void apply() const {
     auto* in_begin = in_->data<InT>();
     auto numel = in_->numel();
     auto* in_end = in_begin + numel;
...
@@ -33,7 +33,7 @@ struct AppendProposalsFunctor {
       : out_(out), offset_(offset), to_add_(to_add) {}

   template <typename T>
-  void operator()() const {
+  void apply() const {
     auto *out_data = out_->data<T>();
     auto *to_add_data = to_add_->data<T>();
     memcpy(out_data + offset_, to_add_data, to_add_->numel() * sizeof(T));
...
@@ -25,7 +25,7 @@ struct FillOpVisitor {
       : tensor_(tensor), value_(value) {}

   template <typename T>
-  void operator()() const {
+  void apply() const {
     platform::CPUPlace cpu;
     auto *data = tensor_->mutable_data<T>(cpu);
     std::transform(value_.data(), value_.data() + tensor_->numel(), data,
...
@@ -55,7 +55,7 @@ struct TensorSetConstantCPU {
   TensorSetConstantCPU(framework::Tensor* tensor, float value)
       : tensor_(tensor), value_(value) {}
   template <typename T>
-  void operator()() const {
+  void apply() const {
     auto cpu = platform::CPUPlace();
     auto* begin = tensor_->mutable_data<T>(cpu);
     std::fill(begin, begin + tensor_->numel(), static_cast<T>(value_));
...
@@ -52,7 +52,7 @@ struct TensorSetConstantGPU {
       : context_(context), tensor_(tensor), value_(value) {}

   template <typename T>
-  void operator()() const {
+  void apply() const {
     SetConstant<platform::CUDADeviceContext, T> functor;
     functor(reinterpret_cast<const platform::CUDADeviceContext&>(context_),
             tensor_, static_cast<T>(value_));
...
@@ -41,7 +41,7 @@ struct OneHotOpCUDAFunctor {
       : in_(in), out_(out), depth_(depth), ctx_(ctx) {}

   template <typename OutT>
-  void operator()() const {
+  void apply() const {
     auto* p_in_data = in_->data<InT>();
     auto numel = in_->numel();
     auto* p_out_data = out_->mutable_data<OutT>(ctx_.GetPlace());
...
@@ -31,7 +31,7 @@ struct OneHotOpFunctor {
       : in_(in), out_(out), depth_(depth), ctx_(ctx) {}

   template <typename OutT>
-  void operator()() const {
+  void apply() const {
     auto* p_in_data = in_->data<InT>();
     auto numel = in_->numel();
     auto* p_out_data = out_->mutable_data<OutT>(ctx_.GetPlace());
...
@@ -99,7 +99,7 @@ struct SequenceMaskFunctor {
       : ctx_(ctx), x_(x), y_(y), limits_(limits), maxlen_(maxlen) {}

   template <typename Ty>
-  void operator()() const {
+  void apply() const {
     auto *y_data = y_->mutable_data<Ty>(ctx_.GetPlace());
     platform::ForRange<DeviceContext> for_range(ctx_, limits_);
     for_range(SequenceMaskForRangeFunctor<Tx, Ty>(x_, y_data, maxlen_));
...