提交 fcd6f64b 编写于 作者: D dongzhihong

"redefine random op"

上级 58561d8f
......@@ -23,7 +23,7 @@ template <typename T>
class GaussianRandomOpKernel<platform::CPUPlace, T>
: public framework::OpKernel {
public:
void Compute(const framework::KernelContext& context) const override {
void Compute(const framework::ExecutionContext& context) const override {
auto mean = context.op_.GetAttr<T>("mean");
auto std = context.op_.GetAttr<T>("std");
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
......@@ -41,15 +41,14 @@ class GaussianRandomOpKernel<platform::CPUPlace, T>
class GaussianRandomOp : public framework::OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor*>& inputs,
const std::vector<framework::Tensor*>& outputs) const override {
void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE(inputs.size() == 0, "Input size of RandomOp must be zero.");
PADDLE_ENFORCE(outputs.size() == 1, "Output size of RandomOp must be one.");
PADDLE_ENFORCE(outputs[0] != nullptr,
"Outputs of RandomOp must all be set.");
outputs[0]->Resize(
framework::make_ddim(this->GetAttr<std::vector<int>>("shape")));
auto* tensor = ctx.Output<Tensor>(0);
auto dims = GetAttr(std::vector<int>("shape"));
tensor->Resize(framework::make_ddim(dims));
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册