提交 4f4daa4b 编写于 作者: N nhzlx

cherry-pick from feature/anakin-engine: add data type for zero copy #16313

1. refine anakin engine
2. add data type for zero copy

align dev branch and PaddlePaddle:feature/anakin-engine brach
the cudnn workspace modify was not included for now, because we use a hard code way
in feature/anakin-engine branch. There should be a better way to implement it,
and subsequent submissions will be made.

test=develop
上级 07dcf285
......@@ -71,6 +71,7 @@ void AnakinEngine<TargetT, PrecisionType, RunType>::Execute(
const std::map<std::string, framework::LoDTensor *> &inputs,
const std::map<std::string, framework::LoDTensor *> &outputs,
cudaStream_t stream) {
cudaDeviceSynchronize();
for (const auto &input : inputs) {
auto *tensor = input.second;
auto *data = tensor->data<float>();
......
......@@ -74,6 +74,19 @@ T *ZeroCopyTensor::data(PaddlePlace *place, int *size) const {
return res;
}
PaddleDType ZeroCopyTensor::type() {
EAGER_GET_TENSOR;
auto type = tensor->type();
if (type == framework::proto::VarType::FP32) {
return PaddleDType::FLOAT32;
} else if (type == framework::proto::VarType::INT64) {
return PaddleDType::INT64;
} else {
LOG(ERROR) << "unknown type, only support float32 and int64 now.";
}
return PaddleDType::FLOAT32;
}
template <typename T>
void ZeroCopyTensor::copy_from_cpu(const T *data) {
EAGER_GET_TENSOR;
......@@ -119,6 +132,7 @@ void ZeroCopyTensor::copy_to_cpu(T *data) {
static_cast<const platform::CUDADeviceContext *>(pool.Get(gpu_place));
memory::Copy(platform::CPUPlace(), static_cast<void *>(data), gpu_place,
t_data, ele_num * sizeof(T), dev_ctx->stream());
cudaDeviceSynchronize();
#else
PADDLE_THROW("Not compile with CUDA, should not reach here.");
#endif
......
......@@ -177,6 +177,8 @@ class ZeroCopyTensor {
device_ = device;
}
PaddleDType type();
protected:
explicit ZeroCopyTensor(void* scope) : scope_{scope} {}
void SetName(const std::string& name) { name_ = name; }
......@@ -191,6 +193,7 @@ class ZeroCopyTensor {
// performance.
mutable void* tensor_{nullptr};
PaddlePlace place_;
PaddleDType dtype_;
int device_;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册