提交 ca225868 编写于 作者: T tangwei12

code optimize

(cherry picked from commit 587cca7e)
上级 557be6fc
...@@ -15,7 +15,6 @@ limitations under the License. */ ...@@ -15,7 +15,6 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -41,19 +40,33 @@ class FillConstantOp : public framework::OperatorBase { ...@@ -41,19 +40,33 @@ class FillConstantOp : public framework::OperatorBase {
static_cast<framework::proto::VarType::Type>(Attr<int>("dtype")); static_cast<framework::proto::VarType::Type>(Attr<int>("dtype"));
auto value = Attr<float>("value"); auto value = Attr<float>("value");
auto force_cpu = Attr<bool>("force_cpu"); auto force_cpu = Attr<bool>("force_cpu");
auto &out =
*scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>(); framework::Tensor *tensor = nullptr;
out.Resize(framework::make_ddim(Attr<std::vector<int>>("shape")));
auto &out_var = *scope.FindVar(Output("Out"));
if (out_var.IsType<framework::LoDTensor>()) {
tensor = out_var.GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim(Attr<std::vector<int>>("shape")));
} else if (out_var.IsType<framework::SelectedRows>()) {
tensor = out_var.GetMutable<framework::SelectedRows>()->mutable_value();
tensor->Resize(framework::make_ddim(Attr<std::vector<int>>("shape")));
} else {
PADDLE_THROW(
"fill constant op's output only"
"supports SelectedRows and LoDTensor");
}
if (force_cpu) { if (force_cpu) {
auto cpu = platform::CPUPlace(); auto cpu = platform::CPUPlace();
out.mutable_data(cpu, framework::ToTypeIndex(data_type)); tensor->mutable_data(cpu, framework::ToTypeIndex(data_type));
} else { } else {
out.mutable_data(dev_place, framework::ToTypeIndex(data_type)); tensor->mutable_data(dev_place, framework::ToTypeIndex(data_type));
} }
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place); auto &dev_ctx = *pool.Get(dev_place);
math::set_constant(dev_ctx, &out, value); math::set_constant(dev_ctx, tensor, value);
} }
}; };
......
...@@ -37,7 +37,7 @@ class CPUUniformRandomKernel : public framework::OpKernel<T> { ...@@ -37,7 +37,7 @@ class CPUUniformRandomKernel : public framework::OpKernel<T> {
} else { } else {
PADDLE_THROW( PADDLE_THROW(
"uniform_random_op's output only" "uniform_random_op's output only"
"supports SelectedRows and Tensor"); "supports SelectedRows and LoDTensor");
} }
T* data = tensor->mutable_data<T>(ctx.GetPlace()); T* data = tensor->mutable_data<T>(ctx.GetPlace());
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed")); unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
......
...@@ -54,7 +54,7 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> { ...@@ -54,7 +54,7 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> {
} else { } else {
PADDLE_THROW( PADDLE_THROW(
"uniform_random_op's output only" "uniform_random_op's output only"
"supports SelectedRows and Tensor"); "supports SelectedRows and LoDTensor");
} }
T* data = tensor->mutable_data<T>(context.GetPlace()); T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed")); unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册