未验证 提交 8a7efc78 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #15882 from sfraczek/unique_ptr_dereference

Change *(smart_ptr.get()) -> *smart_ptr
...@@ -122,7 +122,7 @@ void BeamSearchDecoder<T>::ConvertSentenceVectorToLodTensor( ...@@ -122,7 +122,7 @@ void BeamSearchDecoder<T>::ConvertSentenceVectorToLodTensor(
auto cpu_place = std::unique_ptr<paddle::platform::CPUPlace>( auto cpu_place = std::unique_ptr<paddle::platform::CPUPlace>(
new paddle::platform::CPUPlace()); new paddle::platform::CPUPlace());
paddle::platform::CPUDeviceContext cpu_ctx(*cpu_place.get()); paddle::platform::CPUDeviceContext cpu_ctx(*cpu_place);
framework::LoD lod; framework::LoD lod;
lod.push_back(source_level_lod); lod.push_back(source_level_lod);
......
...@@ -225,7 +225,7 @@ void eltwise_grad(const framework::ExecutionContext &ctx, ...@@ -225,7 +225,7 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_src_mem)); std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_src_mem));
PADDLE_ENFORCE(src_memory != nullptr, PADDLE_ENFORCE(src_memory != nullptr,
"Fail to find src_memory in device context"); "Fail to find src_memory in device context");
src_memory->set_data_handle(*p_src_data.get()); src_memory->set_data_handle(*p_src_data);
std::shared_ptr<memory> diff_src_memory; std::shared_ptr<memory> diff_src_memory;
......
...@@ -198,7 +198,7 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -198,7 +198,7 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
} }
// push primitive to stream and wait until it's executed // push primitive to stream and wait until it's executed
std::vector<mkldnn::primitive> pipeline{*(pool_p.get())}; std::vector<mkldnn::primitive> pipeline{*pool_p};
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
output->set_layout(DataLayout::kMKLDNN); output->set_layout(DataLayout::kMKLDNN);
...@@ -367,8 +367,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -367,8 +367,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
dev_ctx.SetBlob(key_pool_diff_dst_mem_p, diff_dst_memory); dev_ctx.SetBlob(key_pool_diff_dst_mem_p, diff_dst_memory);
pool_bwd_p = std::make_shared<pooling_backward>( pool_bwd_p = std::make_shared<pooling_backward>(
pool_bwd_pd, *(diff_dst_memory.get()), *workspace_memory, pool_bwd_pd, *diff_dst_memory, *workspace_memory, *diff_src_memory);
*(diff_src_memory));
dev_ctx.SetBlob(key_pool_bwd_p, pool_bwd_p); dev_ctx.SetBlob(key_pool_bwd_p, pool_bwd_p);
} else { } else {
...@@ -404,7 +403,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -404,7 +403,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
if (is_diff_dst_reordered) { if (is_diff_dst_reordered) {
pipeline.push_back(reorder_diff_dst); pipeline.push_back(reorder_diff_dst);
} }
pipeline.push_back(*(pool_bwd_p.get())); pipeline.push_back(*pool_bwd_p);
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
in_x_grad->set_layout(DataLayout::kMKLDNN); in_x_grad->set_layout(DataLayout::kMKLDNN);
......
...@@ -66,8 +66,7 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler { ...@@ -66,8 +66,7 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
"Fail to find softmax primitive in device context"); "Fail to find softmax primitive in device context");
if (softmax_p == nullptr) { if (softmax_p == nullptr) {
softmax_p = std::make_shared<mkldnn::softmax_forward>( softmax_p = std::make_shared<mkldnn::softmax_forward>(
*(softmax_pd_.get()), *softmax_pd_, *(static_cast<mkldnn::memory*>(src_memory_p.get())),
*(static_cast<mkldnn::memory*>(src_memory_p.get())),
*(static_cast<mkldnn::memory*>(dst_memory_p.get()))); *(static_cast<mkldnn::memory*>(dst_memory_p.get())));
dev_ctx_.SetBlob(prim_key, softmax_p); dev_ctx_.SetBlob(prim_key, softmax_p);
} else { } else {
...@@ -88,8 +87,8 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler { ...@@ -88,8 +87,8 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
"Fail to find softmax backward primitive in device context"); "Fail to find softmax backward primitive in device context");
if (softmax_bwd_p == nullptr) { if (softmax_bwd_p == nullptr) {
softmax_bwd_p = std::make_shared<mkldnn::softmax_backward>( softmax_bwd_p = std::make_shared<mkldnn::softmax_backward>(
*softmax_bwd_pd_, *(dst_memory_p.get()), *(diff_dst_memory_p.get()), *softmax_bwd_pd_, *dst_memory_p, *diff_dst_memory_p,
*(diff_src_memory_p.get())); *diff_src_memory_p);
dev_ctx_.SetBlob(prim_key, softmax_bwd_p); dev_ctx_.SetBlob(prim_key, softmax_bwd_p);
} else { } else {
is_reusing_ = true; is_reusing_ = true;
......
...@@ -160,7 +160,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -160,7 +160,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto get_selected_row = [&](size_t i) -> const SelectedRows& { auto get_selected_row = [&](size_t i) -> const SelectedRows& {
if (i == 0 && in0) { if (i == 0 && in0) {
return *in0.get(); return *in0;
} else { } else {
return in_vars[i]->Get<SelectedRows>(); return in_vars[i]->Get<SelectedRows>();
} }
......
...@@ -394,7 +394,7 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name, ...@@ -394,7 +394,7 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name,
int tid = platform::get_cur_thread_id(); int tid = platform::get_cur_thread_id();
std::lock_guard<std::mutex> lock(*p_mutex_.get()); std::lock_guard<std::mutex> lock(*p_mutex_);
// Find KeyBlob for current thread // Find KeyBlob for current thread
auto map_it = pMap->find(tid); auto map_it = pMap->find(tid);
...@@ -427,7 +427,7 @@ std::shared_ptr<void> MKLDNNDeviceContext::GetBlob( ...@@ -427,7 +427,7 @@ std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
int tid = platform::get_cur_thread_id(); int tid = platform::get_cur_thread_id();
std::lock_guard<std::mutex> lock(*p_mutex_.get()); std::lock_guard<std::mutex> lock(*p_mutex_);
// Find KeyBlob for current thread firstly // Find KeyBlob for current thread firstly
auto map_it = pMap->find(tid); auto map_it = pMap->find(tid);
......
...@@ -548,9 +548,8 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { ...@@ -548,9 +548,8 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
PADDLE_ENFORCE((conv_p != nullptr) || (is_reusing_ == false), PADDLE_ENFORCE((conv_p != nullptr) || (is_reusing_ == false),
"Fail to find convolution primitive in device context"); "Fail to find convolution primitive in device context");
if (conv_p == nullptr) { if (conv_p == nullptr) {
conv_p = std::make_shared<forward_t>(*conv_pd_, *(src_memory_p), conv_p = std::make_shared<forward_t>(*conv_pd_, *src_memory_p,
*(weights_memory_p.get()), *weights_memory_p, *dst_memory_p);
*(dst_memory_p.get()));
dev_ctx_.SetBlob(prim_key, conv_p); dev_ctx_.SetBlob(prim_key, conv_p);
} else { } else {
...@@ -570,9 +569,9 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { ...@@ -570,9 +569,9 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
PADDLE_ENFORCE((conv_p != nullptr) || (is_reusing_ == false), PADDLE_ENFORCE((conv_p != nullptr) || (is_reusing_ == false),
"Fail to find convolution primitive in device context"); "Fail to find convolution primitive in device context");
if (conv_p == nullptr) { if (conv_p == nullptr) {
conv_p = std::make_shared<forward_t>( conv_p = std::make_shared<forward_t>(*conv_pd_, *src_memory_p,
*conv_pd_, *(src_memory_p), *(weights_memory_p.get()), *weights_memory_p, *bias_memory_p,
*(bias_memory_p.get()), *(dst_memory_p.get())); *dst_memory_p);
dev_ctx_.SetBlob(prim_key, conv_p); dev_ctx_.SetBlob(prim_key, conv_p);
} else { } else {
......
...@@ -73,7 +73,7 @@ int main() { ...@@ -73,7 +73,7 @@ int main() {
PADDLE_ENFORCE_NE(loss_name, "", "loss not found"); PADDLE_ENFORCE_NE(loss_name, "", "loss not found");
// init all parameters // init all parameters
executor.Run(*startup_program.get(), &scope, 0); executor.Run(*startup_program, &scope, 0);
// prepare data // prepare data
auto x_var = scope.Var("x"); auto x_var = scope.Var("x");
...@@ -101,7 +101,7 @@ int main() { ...@@ -101,7 +101,7 @@ int main() {
clock_t t1 = clock(); clock_t t1 = clock();
for (int i = 0; i < 10; ++i) { for (int i = 0; i < 10; ++i) {
executor.Run(*train_program.get(), &scope, 0, false, true); executor.Run(*train_program, &scope, 0, false, true);
std::cout << "step: " << i << " loss: " std::cout << "step: " << i << " loss: "
<< loss_var->Get<paddle::framework::LoDTensor>().data<float>()[0] << loss_var->Get<paddle::framework::LoDTensor>().data<float>()[0]
<< std::endl; << std::endl;
......
...@@ -74,7 +74,7 @@ void Train() { ...@@ -74,7 +74,7 @@ void Train() {
float first_loss = 0.0; float first_loss = 0.0;
float last_loss = 0.0; float last_loss = 0.0;
for (int i = 0; i < 100; ++i) { for (int i = 0; i < 100; ++i) {
executor.Run(*train_program.get(), &scope, 0, false, true); executor.Run(*train_program, &scope, 0, false, true);
if (i == 0) { if (i == 0) {
first_loss = loss_var->Get<framework::LoDTensor>().data<float>()[0]; first_loss = loss_var->Get<framework::LoDTensor>().data<float>()[0];
} else if (i == 99) { } else if (i == 99) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册