gaussian_random_op.cc 2.1 KB
Newer Older
D
dongzhihong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

Q
qijun 已提交
15
#include "paddle/operators/gaussian_random_op.h"
D
dongzhihong 已提交
16 17 18

namespace paddle {
namespace operators {
D
dongzhihong 已提交
19

D
dongzhihong 已提交
20
class GaussianRandomOp : public framework::OperatorWithKernel {
Y
Yu Yang 已提交
21 22
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
23

D
dongzhihong 已提交
24
 protected:
25 26 27 28 29
  void InferShape(const framework::InferShapeContext& context) const override {
    auto* tensor = context.Output<framework::Tensor>(0);
    auto dims = GetAttr<std::vector<int>>("dims");
    PADDLE_ENFORCE(dims.size() > 0UL,
                   "dims can be one int or array. dims must be set.");
D
dongzhihong 已提交
30
    tensor->Resize(framework::make_ddim(dims));
D
dongzhihong 已提交
31 32 33
  }
};

D
dongzhihong 已提交
34
class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
D
dongzhihong 已提交
35
 public:
D
dongzhihong 已提交
36 37
  GaussianRandomOpMaker(framework::OpProto* proto,
                        framework::OpAttrChecker* op_checker)
D
dongzhihong 已提交
38 39 40
      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    AddOutput("Out", "output matrix of random op");
    AddComment(R"DOC(
41 42
GaussianRandom operator.
Use to initialize tensor with gaussian random generator.
D
dongzhihong 已提交
43
)DOC");
44 45 46 47

    AddAttr<std::vector<int>>("dims", "The dimension of random tensor.");
    AddAttr<float>("mean", "mean value of random.").SetDefault(.0f);
    AddAttr<float>("std", "minimum value of random value.").SetDefault(1.0f);
D
dongzhihong 已提交
48 49 50 51 52 53
  }
};

}  // namespace operators
}  // namespace paddle

54
namespace ops = paddle::operators;
F
fengjiayi 已提交
55 56
REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp,
                             ops::GaussianRandomOpMaker);
Q
qijun 已提交
57 58 59
REGISTER_OP_CPU_KERNEL(
    gaussian_random,
    ops::GaussianRandomKernel<paddle::platform::CPUPlace, float>);