Unverified commit ab1097cd authored by D dzhwinter, committed by GitHub

Feature/template (#13093)

* remove template operator

* "fix compile"

* "fix ci"

* "fix ci"
Parent d0c65bff
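
The change is mechanical: each data-type visitor functor renames its templated `operator()` to a named member template `apply`, and `VisitDataType` calls `visitor.template apply<T>()` instead of `visitor.template operator()<T>()`. The sketch below is a minimal, self-contained analog of the resulting pattern; `DataType`, `SizeOfVisitor`, and the reduced type list are simplified stand-ins for the PaddlePaddle types touched in this diff, not the real API.

```cpp
// Minimal analog of the visitor pattern this commit moves to (not the real
// PaddlePaddle code): the visitor exposes a templated `apply<T>()` member
// instead of a templated `operator()<T>()`, and the dispatcher picks T from a
// runtime type tag.
#include <cstdint>
#include <iostream>
#include <stdexcept>

enum class DataType { FP32, FP64, INT32 };  // stand-in for proto::VarType::Type

// Hypothetical visitor: prints the size of the element type it is applied to.
struct SizeOfVisitor {
  template <typename T>
  void apply() const {  // was `void operator()() const` before the rename
    std::cout << "element size: " << sizeof(T) << " bytes\n";
  }
};

template <typename Visitor>
void VisitDataType(DataType type, Visitor visitor) {
  switch (type) {
    case DataType::FP32:
      visitor.template apply<float>();  // `.template` because Visitor is a dependent type
      break;
    case DataType::FP64:
      visitor.template apply<double>();
      break;
    case DataType::INT32:
      visitor.template apply<int32_t>();
      break;
    default:
      throw std::runtime_error("Not supported");
  }
}

int main() {
  VisitDataType(DataType::FP64, SizeOfVisitor{});  // prints "element size: 8 bytes"
  return 0;
}
```

Dispatching through a named member template keeps the explicit template-argument syntax simple at every call site, which matters here because the visitor's type is always a dependent template parameter.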
......@@ -46,7 +46,7 @@ struct CastDataLayout {
const std::vector<int> axis_;
template <typename T>
void operator()() {
void apply() {
auto place = ctx_->GetPlace();
if (platform::is_cpu_place(place)) {
......
......@@ -26,75 +26,40 @@ namespace framework {
extern proto::VarType::Type ToDataType(std::type_index type);
extern std::type_index ToTypeIndex(proto::VarType::Type type);
#if !defined(_WIN32)
template <typename Visitor>
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
switch (type) {
case proto::VarType::FP16:
visitor.template operator()<platform::float16>();
visitor.template apply<platform::float16>();
break;
case proto::VarType::FP32:
visitor.template operator()<float>();
visitor.template apply<float>();
break;
case proto::VarType::FP64:
visitor.template operator()<double>();
visitor.template apply<double>();
break;
case proto::VarType::INT32:
visitor.template operator()<int>();
visitor.template apply<int>();
break;
case proto::VarType::INT64:
visitor.template operator()<int64_t>();
visitor.template apply<int64_t>();
break;
case proto::VarType::BOOL:
visitor.template operator()<bool>();
visitor.template apply<bool>();
break;
case proto::VarType::UINT8:
visitor.template operator()<uint8_t>();
visitor.template apply<uint8_t>();
break;
case proto::VarType::INT16:
visitor.template operator()<int16_t>();
visitor.template apply<int16_t>();
break;
case proto::VarType::INT8:
visitor.template operator()<int8_t>();
visitor.template apply<int8_t>();
break;
default:
PADDLE_THROW("Not supported %d", type);
}
}
#else
// The MSVC compiler does not implement two-phase name lookup correctly.
template <typename Visitor>
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
switch (type) {
case proto::VarType::FP16:
visitor.operator()<platform::float16>();
break;
case proto::VarType::FP32:
visitor.operator()<float>();
break;
case proto::VarType::FP64:
visitor.operator()<double>();
break;
case proto::VarType::INT32:
visitor.operator()<int>();
break;
case proto::VarType::INT64:
visitor.operator()<int64_t>();
break;
case proto::VarType::BOOL:
visitor.operator()<bool>();
break;
case proto::VarType::UINT8:
visitor.operator()<uint8_t>();
break;
case proto::VarType::INT16:
visitor.operator()<int16_t>();
break;
default:
PADDLE_THROW("Not supported %d", type);
}
}
#endif // _WIN32
extern std::string DataTypeToString(const proto::VarType::Type type);
extern size_t SizeOfType(std::type_index type);
......
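
The `#else` branch above carries the actual motivation: its comment notes that MSVC does not implement two-phase name lookup correctly, so the conforming `visitor.template operator()<T>()` spelling could not be used on Windows and a second copy of `VisitDataType` written as `visitor.operator()<T>()` had to be maintained. Judging from the hunk header (75 lines shrink to 40), renaming the member to `apply` apparently lets a single spelling compile everywhere, so that duplicate branch can go away. Below is a small sketch of the three spellings; `Functor` and `CallOnFloat` are hypothetical names, not identifiers from this repository.

```cpp
// Sketch of the syntax problem behind the removed #ifdef (not repository code).
struct Functor {
  template <typename T>
  void operator()() const {}  // old style: templated call operator
  template <typename T>
  void apply() const {}       // new style: ordinary named member template
};

template <typename F>
void CallOnFloat(F f) {
  // Conforming spelling for the old style; MSVC's incomplete two-phase name
  // lookup rejected it, per the comment in the removed branch:
  //   f.template operator()<float>();
  // Spelling the #else branch used instead because MSVC accepts it, even
  // though C++11/14 formally require the `template` keyword here:
  //   f.operator()<float>();
  // After the rename, one spelling works across the supported compilers:
  f.template apply<float>();
}

int main() {
  CallOnFloat(Functor{});
  return 0;
}
```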
......@@ -37,7 +37,7 @@ struct CastDataType {
const platform::DeviceContext* ctx_;
template <typename OutType>
void operator()() {
void apply() {
auto* in_begin = in_.data<InType>();
auto* in_end = in_begin + in_.numel();
auto* out_begin = out_->mutable_data<OutType>(in_.place());
......
......@@ -31,7 +31,7 @@ struct ReduceLoDTensor {
: src_tensors_(src), dst_tensor_(*dst) {}
template <typename T>
void operator()() const {
void apply() const {
PADDLE_ENFORCE(!src_tensors_.empty());
auto &t0 = *src_tensors_[0];
PADDLE_ENFORCE_NE(t0.numel(), 0);
......
......@@ -49,7 +49,7 @@ struct TensorCopyVisitor {
size_(size) {}
template <typename T>
void operator()() const {
void apply() const {
// TODO(Yancey1989): support other place
platform::CPUPlace cpu;
memory::Copy(cpu, dst_->mutable_data<T>(cpu) + dst_offset_, cpu,
......
......@@ -149,7 +149,7 @@ struct AnyDTypeVisitor {
: predicate_(predicate), tensor_(tensor), ctx_(ctx), out_(out) {}
template <typename T>
void operator()() const {
void apply() const {
auto t = EigenVector<T>::Flatten(tensor_);
auto o = EigenScalar<bool>::From(*out_);
// return any of predicate_(t) is true.
......@@ -302,7 +302,7 @@ struct DeserializedDataFunctor {
: buf_(buf), tensor_(tensor), place_(place) {}
template <typename T>
void operator()() {
void apply() {
*buf_ = tensor_->mutable_data<T>(place_);
}
......
......@@ -74,7 +74,7 @@ struct BeamSearchDecodeFunctor {
}
template <typename T>
void operator()() const;
void apply() const;
bool tensor_on_gpu_;
size_t beam_size_;
......@@ -88,7 +88,7 @@ struct BeamSearchDecodeFunctor {
};
template <typename T>
void BeamSearchDecodeFunctor::operator()() const {
void BeamSearchDecodeFunctor::apply() const {
BeamSearchDecoder<T> beam_search_decoder(beam_size_, end_id_);
// Check if the tensor is on GPU. If so, use the CPU copy instead
if (tensor_on_gpu_) {
......@@ -101,7 +101,7 @@ void BeamSearchDecodeFunctor::operator()() const {
}
template <>
void BeamSearchDecodeFunctor::operator()<bool>() const {
void BeamSearchDecodeFunctor::apply<bool>() const {
PADDLE_THROW("beam search decode op does not support bool!");
}
......
......@@ -37,7 +37,7 @@ struct CastOpFunctor {
: in_(in), out_(out), ctx_(ctx) {}
template <typename OutT>
void operator()() const {
void apply() const {
auto* in_begin = in_->data<InT>();
auto numel = in_->numel();
auto* in_end = in_begin + numel;
......
......@@ -33,7 +33,7 @@ struct AppendProposalsFunctor {
: out_(out), offset_(offset), to_add_(to_add) {}
template <typename T>
void operator()() const {
void apply() const {
auto *out_data = out_->data<T>();
auto *to_add_data = to_add_->data<T>();
memcpy(out_data + offset_, to_add_data, to_add_->numel() * sizeof(T));
......
......@@ -25,7 +25,7 @@ struct FillOpVisitor {
: tensor_(tensor), value_(value) {}
template <typename T>
void operator()() const {
void apply() const {
platform::CPUPlace cpu;
auto *data = tensor_->mutable_data<T>(cpu);
std::transform(value_.data(), value_.data() + tensor_->numel(), data,
......
......@@ -55,7 +55,7 @@ struct TensorSetConstantCPU {
TensorSetConstantCPU(framework::Tensor* tensor, float value)
: tensor_(tensor), value_(value) {}
template <typename T>
void operator()() const {
void apply() const {
auto cpu = platform::CPUPlace();
auto* begin = tensor_->mutable_data<T>(cpu);
std::fill(begin, begin + tensor_->numel(), static_cast<T>(value_));
......
......@@ -52,7 +52,7 @@ struct TensorSetConstantGPU {
: context_(context), tensor_(tensor), value_(value) {}
template <typename T>
void operator()() const {
void apply() const {
SetConstant<platform::CUDADeviceContext, T> functor;
functor(reinterpret_cast<const platform::CUDADeviceContext&>(context_),
tensor_, static_cast<T>(value_));
......
......@@ -41,7 +41,7 @@ struct OneHotOpCUDAFunctor {
: in_(in), out_(out), depth_(depth), ctx_(ctx) {}
template <typename OutT>
void operator()() const {
void apply() const {
auto* p_in_data = in_->data<InT>();
auto numel = in_->numel();
auto* p_out_data = out_->mutable_data<OutT>(ctx_.GetPlace());
......
......@@ -31,7 +31,7 @@ struct OneHotOpFunctor {
: in_(in), out_(out), depth_(depth), ctx_(ctx) {}
template <typename OutT>
void operator()() const {
void apply() const {
auto* p_in_data = in_->data<InT>();
auto numel = in_->numel();
auto* p_out_data = out_->mutable_data<OutT>(ctx_.GetPlace());
......
......@@ -99,7 +99,7 @@ struct SequenceMaskFunctor {
: ctx_(ctx), x_(x), y_(y), limits_(limits), maxlen_(maxlen) {}
template <typename Ty>
void operator()() const {
void apply() const {
auto *y_data = y_->mutable_data<Ty>(ctx_.GetPlace());
platform::ForRange<DeviceContext> for_range(ctx_, limits_);
for_range(SequenceMaskForRangeFunctor<Tx, Ty>(x_, y_data, maxlen_));
......