提交 8eac2a46 编写于 作者: Y Yancey1989

update by comment

上级 9e9f5d80
...@@ -24,7 +24,7 @@ template <typename T> ...@@ -24,7 +24,7 @@ template <typename T>
class CPUUniformRandomKernel : public framework::OpKernel<T> { class CPUUniformRandomKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
framework::Tensor* tensor(nullptr); framework::Tensor* tensor = nullptr;
auto out_var = ctx.OutputVar("Out"); auto out_var = ctx.OutputVar("Out");
if (out_var->IsType<framework::LoDTensor>()) { if (out_var->IsType<framework::LoDTensor>()) {
tensor = out_var->GetMutable<framework::LoDTensor>(); tensor = out_var->GetMutable<framework::LoDTensor>();
...@@ -33,7 +33,9 @@ class CPUUniformRandomKernel : public framework::OpKernel<T> { ...@@ -33,7 +33,9 @@ class CPUUniformRandomKernel : public framework::OpKernel<T> {
tensor = out_var->GetMutable<framework::SelectedRows>()->mutable_value(); tensor = out_var->GetMutable<framework::SelectedRows>()->mutable_value();
tensor->Resize(framework::make_ddim(shape)); tensor->Resize(framework::make_ddim(shape));
} else { } else {
PADDLE_THROW("Only support SelectedRows and Tensor"); PADDLE_THROW(
"uniform_random_op's output only"
"supports SelectedRows and Tensor");
} }
T* data = tensor->mutable_data<T>(ctx.GetPlace()); T* data = tensor->mutable_data<T>(ctx.GetPlace());
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed")); unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
......
...@@ -43,7 +43,7 @@ template <typename T> ...@@ -43,7 +43,7 @@ template <typename T>
class GPUUniformRandomKernel : public framework::OpKernel<T> { class GPUUniformRandomKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
framework::Tensor* tensor(nullptr); framework::Tensor* tensor = nullptr;
auto out_var = context.OutputVar("Out"); auto out_var = context.OutputVar("Out");
if (out_var->IsType<framework::LoDTensor>()) { if (out_var->IsType<framework::LoDTensor>()) {
tensor = out_var->GetMutable<framework::LoDTensor>(); tensor = out_var->GetMutable<framework::LoDTensor>();
...@@ -52,7 +52,9 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> { ...@@ -52,7 +52,9 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> {
tensor = out_var->GetMutable<framework::SelectedRows>()->mutable_value(); tensor = out_var->GetMutable<framework::SelectedRows>()->mutable_value();
tensor->Resize(framework::make_ddim(shape)); tensor->Resize(framework::make_ddim(shape));
} else { } else {
PADDLE_THROW("Only support SelectedRows and Tensor"); PADDLE_THROW(
"uniform_random_op's output only"
"supports SelectedRows and Tensor");
} }
T* data = tensor->mutable_data<T>(context.GetPlace()); T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed")); unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册