提交 8eac2a46 编写于 作者: Y Yancey1989

update by comment

上级 9e9f5d80
......@@ -24,7 +24,7 @@ template <typename T>
class CPUUniformRandomKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
framework::Tensor* tensor(nullptr);
framework::Tensor* tensor = nullptr;
auto out_var = ctx.OutputVar("Out");
if (out_var->IsType<framework::LoDTensor>()) {
tensor = out_var->GetMutable<framework::LoDTensor>();
......@@ -33,7 +33,9 @@ class CPUUniformRandomKernel : public framework::OpKernel<T> {
tensor = out_var->GetMutable<framework::SelectedRows>()->mutable_value();
tensor->Resize(framework::make_ddim(shape));
} else {
PADDLE_THROW("Only support SelectedRows and Tensor");
PADDLE_THROW(
"uniform_random_op's output only"
"supports SelectedRows and Tensor");
}
T* data = tensor->mutable_data<T>(ctx.GetPlace());
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
......
......@@ -43,7 +43,7 @@ template <typename T>
class GPUUniformRandomKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
framework::Tensor* tensor(nullptr);
framework::Tensor* tensor = nullptr;
auto out_var = context.OutputVar("Out");
if (out_var->IsType<framework::LoDTensor>()) {
tensor = out_var->GetMutable<framework::LoDTensor>();
......@@ -52,7 +52,9 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> {
tensor = out_var->GetMutable<framework::SelectedRows>()->mutable_value();
tensor->Resize(framework::make_ddim(shape));
} else {
PADDLE_THROW("Only support SelectedRows and Tensor");
PADDLE_THROW(
"uniform_random_op's output only"
"supports SelectedRows and Tensor");
}
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册