From ca22586818c2ce9d9b4ac83f49a3c7a54570cc6b Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Fri, 24 Aug 2018 10:51:37 +0800 Subject: [PATCH] code optimize (cherry picked from commit 587cca7) --- paddle/fluid/operators/fill_constant_op.cc | 27 +++++++++++++++------ paddle/fluid/operators/uniform_random_op.cc | 2 +- paddle/fluid/operators/uniform_random_op.cu | 2 +- 3 files changed, 22 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc index 130f18dde4..2826b82117 100644 --- a/paddle/fluid/operators/fill_constant_op.cc +++ b/paddle/fluid/operators/fill_constant_op.cc @@ -15,7 +15,6 @@ limitations under the License. */ #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/math_function.h" -#include "paddle/fluid/platform/device_context.h" namespace paddle { namespace operators { @@ -41,19 +40,33 @@ class FillConstantOp : public framework::OperatorBase { static_cast(Attr("dtype")); auto value = Attr("value"); auto force_cpu = Attr("force_cpu"); - auto &out = - *scope.FindVar(Output("Out"))->GetMutable(); - out.Resize(framework::make_ddim(Attr>("shape"))); + + framework::Tensor *tensor = nullptr; + + auto &out_var = *scope.FindVar(Output("Out")); + + if (out_var.IsType()) { + tensor = out_var.GetMutable(); + tensor->Resize(framework::make_ddim(Attr>("shape"))); + } else if (out_var.IsType()) { + tensor = out_var.GetMutable()->mutable_value(); + tensor->Resize(framework::make_ddim(Attr>("shape"))); + } else { + PADDLE_THROW( + "fill constant op's output only" + "supports SelectedRows and LoDTensor"); + } + if (force_cpu) { auto cpu = platform::CPUPlace(); - out.mutable_data(cpu, framework::ToTypeIndex(data_type)); + tensor->mutable_data(cpu, framework::ToTypeIndex(data_type)); } else { - out.mutable_data(dev_place, framework::ToTypeIndex(data_type)); + tensor->mutable_data(dev_place, framework::ToTypeIndex(data_type)); } platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(dev_place); - math::set_constant(dev_ctx, &out, value); + math::set_constant(dev_ctx, tensor, value); } }; diff --git a/paddle/fluid/operators/uniform_random_op.cc b/paddle/fluid/operators/uniform_random_op.cc index 5248767c2e..763bb40358 100644 --- a/paddle/fluid/operators/uniform_random_op.cc +++ b/paddle/fluid/operators/uniform_random_op.cc @@ -37,7 +37,7 @@ class CPUUniformRandomKernel : public framework::OpKernel { } else { PADDLE_THROW( "uniform_random_op's output only" - "supports SelectedRows and Tensor"); + "supports SelectedRows and LoDTensor"); } T* data = tensor->mutable_data(ctx.GetPlace()); unsigned int seed = static_cast(ctx.Attr("seed")); diff --git a/paddle/fluid/operators/uniform_random_op.cu b/paddle/fluid/operators/uniform_random_op.cu index e1c7323a30..bbb692b0dd 100644 --- a/paddle/fluid/operators/uniform_random_op.cu +++ b/paddle/fluid/operators/uniform_random_op.cu @@ -54,7 +54,7 @@ class GPUUniformRandomKernel : public framework::OpKernel { } else { PADDLE_THROW( "uniform_random_op's output only" - "supports SelectedRows and Tensor"); + "supports SelectedRows and LoDTensor"); } T* data = tensor->mutable_data(context.GetPlace()); unsigned int seed = static_cast(context.Attr("seed")); -- GitLab