提交 6fcdc916 编写于 作者: Q qiaolongfei

add op() to InferShapeContext

上级 fc8a1afa
...@@ -229,6 +229,10 @@ class InferShapeContext { ...@@ -229,6 +229,10 @@ class InferShapeContext {
InferShapeContext(const OperatorBase& op, const Scope& scope) InferShapeContext(const OperatorBase& op, const Scope& scope)
: op_(op), scope_(scope) {} : op_(op), scope_(scope) {}
const OperatorBase& op() const { return op_; }
const Scope& scope() const { return scope_; }
size_t InputSize(const std::string& name) const { size_t InputSize(const std::string& name) const {
return op_.Inputs(name).size(); return op_.Inputs(name).size();
} }
...@@ -312,6 +316,7 @@ class InferShapeContext { ...@@ -312,6 +316,7 @@ class InferShapeContext {
return res; return res;
} }
private:
const OperatorBase& op_; const OperatorBase& op_;
const Scope& scope_; const Scope& scope_;
}; };
......
...@@ -122,10 +122,10 @@ class CPUKernelTest : public OpKernel { ...@@ -122,10 +122,10 @@ class CPUKernelTest : public OpKernel {
public: public:
void Compute(const ExecutionContext& ctx) const { void Compute(const ExecutionContext& ctx) const {
std::cout << "this is cpu kernel" << std::endl; std::cout << "this is cpu kernel" << std::endl;
std::cout << ctx.op_.DebugString() << std::endl; std::cout << ctx.op().DebugString() << std::endl;
cpu_kernel_run_num++; cpu_kernel_run_num++;
ASSERT_EQ(ctx.op_.Input("x"), "IN1"); ASSERT_EQ(ctx.op().Input("x"), "IN1");
ASSERT_EQ(ctx.op_.Output("y"), "OUT1"); ASSERT_EQ(ctx.op().Output("y"), "OUT1");
} }
}; };
...@@ -148,7 +148,7 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker ...@@ -148,7 +148,7 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker
class CPUKernalMultiInputsTest : public OpKernel { class CPUKernalMultiInputsTest : public OpKernel {
public: public:
void Compute(const ExecutionContext& ctx) const { void Compute(const ExecutionContext& ctx) const {
auto xs = ctx.op_.Inputs("xs"); auto xs = ctx.op().Inputs("xs");
ASSERT_EQ(xs.size(), 3UL); ASSERT_EQ(xs.size(), 3UL);
ASSERT_EQ(xs[0], "x0"); ASSERT_EQ(xs[0], "x0");
ASSERT_EQ(xs[1], "x1"); ASSERT_EQ(xs[1], "x1");
...@@ -172,10 +172,10 @@ class CPUKernalMultiInputsTest : public OpKernel { ...@@ -172,10 +172,10 @@ class CPUKernalMultiInputsTest : public OpKernel {
auto outTensor0 = ctx.MultiOutput<Tensor>("ys"); auto outTensor0 = ctx.MultiOutput<Tensor>("ys");
ASSERT_EQ(outTensor0.size(), 2U); ASSERT_EQ(outTensor0.size(), 2U);
auto k = ctx.op_.Input("k"); auto k = ctx.op().Input("k");
ASSERT_EQ(k, "k0"); ASSERT_EQ(k, "k0");
auto ys = ctx.op_.Outputs("ys"); auto ys = ctx.op().Outputs("ys");
ASSERT_EQ(ys.size(), 2UL); ASSERT_EQ(ys.size(), 2UL);
ASSERT_EQ(ys[0], "y0"); ASSERT_EQ(ys[0], "y0");
ASSERT_EQ(ys[1], "y1"); ASSERT_EQ(ys[1], "y1");
......
...@@ -19,13 +19,13 @@ template <typename T> ...@@ -19,13 +19,13 @@ template <typename T>
class CPUGaussianRandomKernel : public framework::OpKernel { class CPUGaussianRandomKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
float mean = context.op_.GetAttr<float>("mean"); float mean = context.op().GetAttr<float>("mean");
float std = context.op_.GetAttr<float>("std"); float std = context.op().GetAttr<float>("std");
auto* tensor = context.Output<framework::Tensor>("Out"); auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace()); T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = unsigned int seed =
static_cast<unsigned int>(context.op_.GetAttr<int>("seed")); static_cast<unsigned int>(context.op().GetAttr<int>("seed"));
std::minstd_rand engine; std::minstd_rand engine;
if (seed == 0) { if (seed == 0) {
seed = std::random_device()(); seed = std::random_device()();
......
...@@ -29,10 +29,10 @@ class MulOp : public framework::OperatorWithKernel { ...@@ -29,10 +29,10 @@ class MulOp : public framework::OperatorWithKernel {
auto dim1 = ctx.Input<Tensor>("Y")->dims(); auto dim1 = ctx.Input<Tensor>("Y")->dims();
PADDLE_ENFORCE_EQ(dim0.size(), 2, PADDLE_ENFORCE_EQ(dim0.size(), 2,
"input X(%s) should be a tensor with 2 dims, a matrix", "input X(%s) should be a tensor with 2 dims, a matrix",
ctx.op_.Input("X")); ctx.op().Input("X"));
PADDLE_ENFORCE_EQ(dim1.size(), 2, PADDLE_ENFORCE_EQ(dim1.size(), 2,
"input Y(%s) should be a tensor with 2 dims, a matrix", "input Y(%s) should be a tensor with 2 dims, a matrix",
ctx.op_.Input("Y")); ctx.op().Input("Y"));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
dim0[1], dim1[0], dim0[1], dim1[0],
"First matrix's width must be equal with second matrix's height."); "First matrix's width must be equal with second matrix's height.");
......
...@@ -27,7 +27,7 @@ class ScaleKernel : public framework::OpKernel { ...@@ -27,7 +27,7 @@ class ScaleKernel : public framework::OpKernel {
auto* in = context.Input<framework::Tensor>("X"); auto* in = context.Input<framework::Tensor>("X");
tensor->mutable_data<T>(in->place()); tensor->mutable_data<T>(in->place());
auto scale = static_cast<T>(context.op_.GetAttr<AttrType>("scale")); auto scale = static_cast<T>(context.op().GetAttr<AttrType>("scale"));
auto eigen_out = framework::EigenVector<T>::Flatten(*tensor); auto eigen_out = framework::EigenVector<T>::Flatten(*tensor);
auto eigen_in = framework::EigenVector<T>::Flatten(*in); auto eigen_in = framework::EigenVector<T>::Flatten(*in);
......
...@@ -31,7 +31,7 @@ class SGDOpKernel : public framework::OpKernel { ...@@ -31,7 +31,7 @@ class SGDOpKernel : public framework::OpKernel {
auto param = ctx.Input<Tensor>("param"); auto param = ctx.Input<Tensor>("param");
auto grad = ctx.Input<Tensor>("grad"); auto grad = ctx.Input<Tensor>("grad");
auto param_out = ctx.Output<Tensor>("param_out"); auto param_out = ctx.Output<Tensor>("param_out");
float lr = ctx.op_.GetAttr<float>("learning_rate"); float lr = ctx.op().GetAttr<float>("learning_rate");
param_out->mutable_data<T>(ctx.GetPlace()); param_out->mutable_data<T>(ctx.GetPlace());
......
...@@ -27,15 +27,15 @@ class CPUUniformRandomKernel : public framework::OpKernel { ...@@ -27,15 +27,15 @@ class CPUUniformRandomKernel : public framework::OpKernel {
auto* tensor = context.Output<framework::Tensor>("Out"); auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace()); T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = unsigned int seed =
static_cast<unsigned int>(context.op_.GetAttr<int>("seed")); static_cast<unsigned int>(context.op().GetAttr<int>("seed"));
std::minstd_rand engine; std::minstd_rand engine;
if (seed == 0) { if (seed == 0) {
seed = std::random_device()(); seed = std::random_device()();
} }
engine.seed(seed); engine.seed(seed);
std::uniform_real_distribution<T> dist( std::uniform_real_distribution<T> dist(
static_cast<T>(context.op_.GetAttr<float>("min")), static_cast<T>(context.op().GetAttr<float>("min")),
static_cast<T>(context.op_.GetAttr<float>("max"))); static_cast<T>(context.op().GetAttr<float>("max")));
ssize_t size = framework::product(tensor->dims()); ssize_t size = framework::product(tensor->dims());
for (ssize_t i = 0; i < size; ++i) { for (ssize_t i = 0; i < size; ++i) {
data[i] = dist(engine); data[i] = dist(engine);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册