提交 fcd6f64b 编写于 作者: D dongzhihong

"redefine random op"

上级 58561d8f
...@@ -23,7 +23,7 @@ template <typename T> ...@@ -23,7 +23,7 @@ template <typename T>
class GaussianRandomOpKernel<platform::CPUPlace, T> class GaussianRandomOpKernel<platform::CPUPlace, T>
: public framework::OpKernel { : public framework::OpKernel {
public: public:
void Compute(const framework::KernelContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto mean = context.op_.GetAttr<T>("mean"); auto mean = context.op_.GetAttr<T>("mean");
auto std = context.op_.GetAttr<T>("std"); auto std = context.op_.GetAttr<T>("std");
auto* output = context.Output(0)->GetMutable<framework::Tensor>(); auto* output = context.Output(0)->GetMutable<framework::Tensor>();
...@@ -41,15 +41,14 @@ class GaussianRandomOpKernel<platform::CPUPlace, T> ...@@ -41,15 +41,14 @@ class GaussianRandomOpKernel<platform::CPUPlace, T>
class GaussianRandomOp : public framework::OperatorWithKernel { class GaussianRandomOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape( void InferShape(const framework::InferShapeContext& ctx) const override {
const std::vector<const framework::Tensor*>& inputs,
const std::vector<framework::Tensor*>& outputs) const override {
PADDLE_ENFORCE(inputs.size() == 0, "Input size of RandomOp must be zero."); PADDLE_ENFORCE(inputs.size() == 0, "Input size of RandomOp must be zero.");
PADDLE_ENFORCE(outputs.size() == 1, "Output size of RandomOp must be one."); PADDLE_ENFORCE(outputs.size() == 1, "Output size of RandomOp must be one.");
PADDLE_ENFORCE(outputs[0] != nullptr, PADDLE_ENFORCE(outputs[0] != nullptr,
"Outputs of RandomOp must all be set."); "Outputs of RandomOp must all be set.");
outputs[0]->Resize( auto* tensor = ctx.Output<Tensor>(0);
framework::make_ddim(this->GetAttr<std::vector<int>>("shape"))); auto dims = GetAttr(std::vector<int>("shape"));
tensor->Resize(framework::make_ddim(dims));
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册