diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h
index 488fa38faf12ee51087643f79295f36bfd33ee22..c7559cefb6415ee141f32e4357459653564cd2ac 100644
--- a/paddle/framework/attribute.h
+++ b/paddle/framework/attribute.h
@@ -45,6 +45,21 @@ inline AttrType AttrTypeID() {
 
 Attribute GetAttrValue(const OpDesc::Attr& attr_desc);
 
+class AttrReader {
+ public:
+  explicit AttrReader(const AttributeMap& attrs) : attrs_(attrs) {}
+
+  template <typename T>
+  inline const T& Get(const std::string& name) const {
+    PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
+                   name);
+    return boost::get<T>(attrs_.at(name));
+  }
+
+ private:
+  const AttributeMap& attrs_;
+};
+
 // check whether a value(attribute) fit a certain limit
 template <typename T>
 class GreaterThanChecker {
diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc
index e00c6e8d904508ec9985537fc703c7c61a14e0de..b8fdf69683e645d991cf8dc2297b486680445a00 100644
--- a/paddle/framework/op_registry_test.cc
+++ b/paddle/framework/op_registry_test.cc
@@ -174,4 +174,4 @@ TEST(OpRegistry, CustomChecker) {
   op->Run(scope, dev_ctx);
   int test_attr = op->Attr<int>("test_attr");
   ASSERT_EQ(test_attr, 4);
-}
\ No newline at end of file
+}
diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc
index fcbfc3e4377edd0ea55c8d4328c325fa18663001..a3f28339aa64c6bde3fcefdae8b0973a5bbdd585 100644
--- a/paddle/framework/operator.cc
+++ b/paddle/framework/operator.cc
@@ -14,7 +14,6 @@ limitations under the License. */
 
 #include "paddle/framework/operator.h"
 #include <algorithm>
-#include "paddle/framework/op_registry.h"
 
 namespace paddle {
 namespace framework {
@@ -33,6 +32,24 @@ ExecutionContext::GetEigenDevice() const {
 }
 #endif
 
+const Tensor* GetTensorFromVar(const Variable* var) {
+  if (var->IsType<LoDTensor>()) {
+    return &var->Get<LoDTensor>();
+  }
+  PADDLE_ENFORCE(var->IsType<Tensor>(),
+                 "The Input must be LoDTensor or Tensor.");
+  return &var->Get<Tensor>();
+}
+
+Tensor* GetTensorFromVar(Variable* var) {
+  if (var->IsType<LoDTensor>()) {
+    return var->GetMutable<LoDTensor>();
+  }
+  PADDLE_ENFORCE(var->IsType<Tensor>(),
+                 "The Input must be LoDTensor or Tensor.");
+  return var->GetMutable<Tensor>();
+}
+
 std::string OperatorBase::Input(const std::string& name) const {
   auto& ins = Inputs(name);
   PADDLE_ENFORCE_LE(ins.size(), 1UL,
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index 2d6d5510ef6dc83f1a016be6ff123f0b9bcaf230..77c7c855c0ffed5032e639237b01037a990652c4 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -24,6 +24,7 @@ limitations under the License. */
 #include "paddle/framework/framework.pb.h"
 #include "paddle/framework/lod_tensor.h"
 #include "paddle/framework/scope.h"
+#include "paddle/framework/shape_inference.h"
 #include "paddle/framework/tensor.h"
 #include "paddle/platform/device_context.h"
 #include "paddle/platform/place.h"
@@ -56,6 +57,9 @@ class OperatorBase;
 class InferShapeContext;
 class ExecutionContext;
 
+extern const Tensor* GetTensorFromVar(const Variable* var);
+extern Tensor* GetTensorFromVar(Variable* var);
+
 /**
  * OperatorBase has the basic element that Net will call to do computation.
  * Only CreateOperator from OpRegistry will new Operator directly.
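Note (editorial, not part of the patch): a minimal sketch of how the new `AttrReader` is meant to be consumed. The `Example` function and the attribute names are invented for illustration; `AttributeMap` is the `std::unordered_map<std::string, Attribute>` alias from attribute.h.

```cpp
// Hypothetical usage sketch for AttrReader; not part of the patch.
#include <string>
#include <vector>

#include "paddle/framework/attribute.h"

void Example(const paddle::framework::AttributeMap& attrs) {
  paddle::framework::AttrReader reader(attrs);
  // Typed, read-only access: PADDLE_ENFORCE fires if the key is missing,
  // and boost::get throws if the stored type does not match T.
  float max = reader.Get<float>("max");
  std::vector<int> strides = reader.Get<std::vector<int>>("strides");
  (void)max;
  (void)strides;
}
```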
@@ -262,15 +266,6 @@ class InferShapeContext {
     return res;
   }
 
-  const Tensor* GetTensorFromVar(const Variable* var) const {
-    if (var->IsType<LoDTensor>()) {
-      return &var->Get<LoDTensor>();
-    }
-    PADDLE_ENFORCE(var->IsType<Tensor>(),
-                   "The Input(%s) must be LoDTensor or Tensor.");
-    return &var->Get<Tensor>();
-  }
-
   void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
                 size_t j = 0) const {
     PADDLE_ENFORCE_LT(i, InputSize(in));
@@ -340,6 +335,78 @@ class ExecutionContext : public InferShapeContext {
   const platform::DeviceContext& device_context_;
 };
 
+class RuntimeInferShapeContext : public InferShapeContextBase {
+ public:
+  RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
+      : op_(op), scope_(scope) {}
+
+  bool HasInput(const std::string& name) const {
+    auto ipt = op_.Input(name);
+    auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
+    return var != nullptr;
+  }
+
+  bool HasOutput(const std::string& name) const {
+    auto ipt = op_.Output(name);
+    auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
+    return var != nullptr;
+  }
+
+  DDim GetInputDim(const std::string& name) const {
+    return GetDim(op_.Input(name));
+  }
+
+  void SetInputDim(const std::string& name, const DDim& dim) {
+    SetDim(op_.Input(name), dim);
+  }
+
+  DDim GetOutputDim(const std::string& name) const {
+    return GetDim(op_.Output(name));
+  }
+
+  void SetOutputDim(const std::string& name, const DDim& dim) {
+    SetDim(op_.Output(name), dim);
+  }
+
+  AttrReader Attrs() const { return AttrReader(op_.Attrs()); }
+
+  const std::vector<std::string>& Inputs(const std::string& name) const {
+    return op_.Inputs(name);
+  }
+
+  const std::vector<std::string>& Outputs(const std::string& name) const {
+    return op_.Outputs(name);
+  }
+
+ private:
+  template <bool Allocate>
+  Tensor* GetTensor(const std::string& name) const {
+    Tensor* t = nullptr;
+    auto* var = scope_.FindVar(name);
+    if (!var->IsType<LoDTensor>() && !var->IsType<Tensor>()) {
+      if (Allocate) {
+        t = var->GetMutable<LoDTensor>();
+      } else {
+        PADDLE_THROW("Variable(%s) should be tensor", name);
+      }
+    } else {
+      t = GetTensorFromVar(scope_.FindVar(name));
+    }
+    return t;
+  }
+
+  DDim GetDim(const std::string& name) const {
+    return GetTensor<false>(name)->dims();
+  }
+
+  void SetDim(const std::string& name, const DDim& dim) {
+    GetTensor<true>(name)->Resize(dim);
+  }
+
+  const OperatorBase& op_;
+  const Scope& scope_;
+};
+
 class OpKernel {
  public:
   /**
@@ -383,8 +450,10 @@ class OperatorWithKernel : public OperatorBase {
                      const VariableNameMap& outputs, const AttributeMap& attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
 
+  // runtime infershape
   void InferShape(const Scope& scope) const override {
-    InferShape(InferShapeContext(*this, scope));
+    auto c = RuntimeInferShapeContext(*this, scope);
+    InferShape(&c);
   }
 
   void Run(const Scope& scope,
@@ -406,7 +475,7 @@ class OperatorWithKernel : public OperatorBase {
   }
 
  protected:
-  virtual void InferShape(const InferShapeContext& ctx) const = 0;
+  virtual void InferShape(InferShapeContextBase* ctx) const = 0;
 };
 
 }  // namespace framework
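Note (editorial, not part of the patch): because `InferShape` now receives an abstract `InferShapeContextBase*`, shape inference can be backed by stores other than a runtime `Scope`, e.g. a map-backed test double or a future compile-time, `OpDesc`-driven context. A hedged sketch; every `Fake*` name below is invented.

```cpp
// Hypothetical second implementation of InferShapeContextBase, for tests.
#include <map>
#include <string>
#include <vector>

#include "paddle/framework/shape_inference.h"

namespace fw = paddle::framework;

class FakeInferShapeContext : public fw::InferShapeContextBase {
 public:
  FakeInferShapeContext(fw::AttributeMap attrs,
                        std::map<std::string, std::vector<std::string>> io,
                        std::map<std::string, fw::DDim> dims)
      : attrs_(std::move(attrs)), io_(std::move(io)), dims_(std::move(dims)) {}

  bool HasInput(const std::string &name) const override {
    return dims_.count(Inputs(name)[0]) > 0;
  }
  bool HasOutput(const std::string &name) const override {
    return dims_.count(Outputs(name)[0]) > 0;
  }
  fw::DDim GetInputDim(const std::string &name) const override {
    return GetDim(Inputs(name)[0]);
  }
  void SetInputDim(const std::string &name, const fw::DDim &dim) override {
    SetDim(Inputs(name)[0], dim);
  }
  fw::DDim GetOutputDim(const std::string &name) const override {
    return GetDim(Outputs(name)[0]);
  }
  void SetOutputDim(const std::string &name, const fw::DDim &dim) override {
    SetDim(Outputs(name)[0], dim);
  }
  fw::AttrReader Attrs() const override { return fw::AttrReader(attrs_); }
  const std::vector<std::string> &Inputs(
      const std::string &name) const override {
    return io_.at(name);
  }
  const std::vector<std::string> &Outputs(
      const std::string &name) const override {
    return io_.at(name);
  }

 protected:
  // Dims live in a plain map instead of Scope variables.
  fw::DDim GetDim(const std::string &name) const override {
    return dims_.at(name);
  }
  void SetDim(const std::string &name, const fw::DDim &dim) override {
    dims_[name] = dim;
  }

 private:
  fw::AttributeMap attrs_;
  std::map<std::string, std::vector<std::string>> io_;  // param -> var names
  std::map<std::string, fw::DDim> dims_;                // var name -> dims
};
```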
*/ #include "paddle/framework/operator.h" #include "gtest/gtest.h" +#include "paddle/framework/op_info.h" #include "paddle/framework/op_registry.h" namespace paddle { @@ -114,7 +115,7 @@ class OpWithKernelTest : public OperatorWithKernel { using OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext& ctx) const override {} + void InferShape(framework::InferShapeContextBase* ctx) const override {} }; template diff --git a/paddle/framework/shape_inference.h b/paddle/framework/shape_inference.h new file mode 100644 index 0000000000000000000000000000000000000000..b07fc788124413f728c713027609d9d2d1c39538 --- /dev/null +++ b/paddle/framework/shape_inference.h @@ -0,0 +1,82 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/ddim.h" + +namespace paddle { +namespace framework { + +class InferShapeContextBase { + public: + virtual ~InferShapeContextBase() {} + virtual bool HasInput(const std::string &name) const = 0; + virtual bool HasOutput(const std::string &name) const = 0; + virtual framework::DDim GetInputDim(const std::string &name) const = 0; + std::vector GetInputsDim(const std::string &name) const { + const std::vector &names = Inputs(name); + return GetDims(names); + } + virtual void SetInputDim(const std::string &name, + const framework::DDim &dim) = 0; + void SetInputsDim(const std::string &name, + const std::vector &dims) { + auto &names = Inputs(name); + SetDims(names, dims); + } + virtual framework::DDim GetOutputDim(const std::string &name) const = 0; + std::vector GetOutputsDim(const std::string &name) const { + const std::vector &names = Outputs(name); + return GetDims(names); + } + virtual void SetOutputDim(const std::string &name, const DDim &dim) = 0; + void SetOutputsDim(const std::string &name, + const std::vector &dims) { + auto &names = Outputs(name); + SetDims(names, dims); + } + virtual AttrReader Attrs() const = 0; + virtual const std::vector &Inputs( + const std::string &name) const = 0; + virtual const std::vector &Outputs( + const std::string &name) const = 0; + // TODO(qiao) implement this function + void ShareLoD(const std::string &in, const std::string &out, size_t i = 0, + size_t j = 0) const {} + + protected: + virtual framework::DDim GetDim(const std::string &name) const = 0; + virtual void SetDim(const std::string &name, const framework::DDim &dim) = 0; + std::vector GetDims( + const std::vector &names) const { + std::vector ret; + ret.reserve(names.size()); + std::transform( + names.begin(), names.end(), std::back_inserter(ret), + [this](const std::string &name) { return this->GetDim(name); }); + return ret; + } + void SetDims(const std::vector &names, + const std::vector &dims) { + size_t length = names.size(); + PADDLE_ENFORCE_EQ(length, dims.size()); + for (size_t i = 0; i < length; ++i) { + SetDim(names[i], dims[i]); + } + } +}; + +} // namespace framework +} // namespace paddle diff --git 
diff --git a/paddle/operators/accuracy_op.cc b/paddle/operators/accuracy_op.cc
index 70e4f9da1221ab300e2b507a3da2f7c5da93f2e4..82010bfb53e58a0836c99c353590f4e32e25ac4a 100644
--- a/paddle/operators/accuracy_op.cc
+++ b/paddle/operators/accuracy_op.cc
@@ -22,25 +22,23 @@ class AccuracyOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(
-        ctx.InputVar("Inference"),
-        "Input(Inference) of AccuracyOp should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
-                            "Input(Label) of AccuracyOp should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(
-        ctx.OutputVar("Accuracy"),
-        "Output(Accuracy) of AccuracyOp should not be null.");
+  void InferShape(framework::InferShapeContextBase *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Inference"),
+                   "Input(Inference) of AccuracyOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Label"),
+                   "Input(Label) of AccuracyOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Accuracy"),
+                   "Output(Accuracy) of AccuracyOp should not be null.");
 
-    auto *inference = ctx.Input<framework::Tensor>("Inference");
-    auto *label = ctx.Input<framework::Tensor>("Label");
+    auto inference_dim = ctx->GetInputDim("Inference");
+    auto label_dim = ctx->GetInputDim("Label");
 
-    PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label must be a vector");
-    PADDLE_ENFORCE_EQ(inference->dims()[0], label->dims()[0],
+    PADDLE_ENFORCE_EQ(label_dim.size(), 1, "label must be a vector");
+    PADDLE_ENFORCE_EQ(inference_dim[0], label_dim[0],
                       "inference size must be the same as label size");
 
-    ctx.Output<framework::LoDTensor>("Accuracy")->Resize({1});
-    ctx.ShareLoD("Inference", /*->*/ "Accuracy");
+    ctx->SetOutputDim("Accuracy", {1});
+    ctx->ShareLoD("Inference", /*->*/ "Accuracy");
   }
 };
diff --git a/paddle/operators/activation_op.cc b/paddle/operators/activation_op.cc
index 06654702bc42cc7cf4917b00693334b1d36ce371..f77e1c572e33533ac672e3d476a7e6dad122031f 100644
--- a/paddle/operators/activation_op.cc
+++ b/paddle/operators/activation_op.cc
@@ -22,10 +22,9 @@ class ActivationOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    ctx.Output<framework::LoDTensor>("Y")->Resize(
-        ctx.Input<framework::Tensor>("X")->dims());
-    ctx.ShareLoD("X", /*->*/ "Y");
+  void InferShape(framework::InferShapeContextBase *ctx) const override {
+    ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
+    ctx->ShareLoD("X", /*->*/ "Y");
   }
 };
 
@@ -34,9 +33,8 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    ctx.Output<framework::LoDTensor>(framework::GradVarName("X"))
-        ->Resize(ctx.Input<framework::Tensor>("Y")->dims());
+  void InferShape(framework::InferShapeContextBase *ctx) const override {
+    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Y"));
   }
 };
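Note (editorial, not part of the patch): a hedged sketch of exercising the runtime path end to end, assuming the `Scope`/`OpRegistry` APIs of this revision; the op type and variable names here are assumptions, not taken from the diff.

```cpp
// Hedged sketch: driving the new shape-inference path from a test.
paddle::framework::Scope scope;
scope.NewVar("x")->GetMutable<paddle::framework::LoDTensor>()->Resize({10, 20});
scope.NewVar("y")->GetMutable<paddle::framework::LoDTensor>();

auto op = paddle::framework::OpRegistry::CreateOp(
    "sigmoid", {{"X", {"x"}}}, {{"Y", {"y"}}}, {});
// Internally: RuntimeInferShapeContext ctx(*op, scope); op->InferShape(&ctx);
op->InferShape(scope);
// "y" now has the shape of "x", i.e. {10, 20}.
```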
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), - "Input(Y) of AddOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), - "Output(Out) of AddOp should not be null."); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of AddOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) of AddOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of AddOp should not be null."); - PADDLE_ENFORCE_EQ(ctx.Input("X")->dims(), - ctx.Input("Y")->dims(), + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + PADDLE_ENFORCE_EQ(x_dims, y_dims, "Two input of Add Op's dimension must be same."); - ctx.Output("Out")->Resize( - ctx.Input("X")->dims()); + ctx->SetOutputDim("Out", x_dims); } }; class AddOpMaker : public framework::OpProtoAndCheckerMaker { public: - AddOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + AddOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The first input of add op"); AddInput("Y", "The second input of add op"); @@ -58,7 +56,7 @@ class AddOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override {} + void InferShape(framework::InferShapeContextBase* ctx) const override {} }; } // namespace operators diff --git a/paddle/operators/clip_op.cc b/paddle/operators/clip_op.cc index e5a54bc4b226fd24337050fdd84b2de9c49f7949..316d28f174658de0e20ed9512f315da305bbb6d0 100644 --- a/paddle/operators/clip_op.cc +++ b/paddle/operators/clip_op.cc @@ -22,24 +22,24 @@ class ClipOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input(X) of ClipOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), - "Output(Out) of ClipOp should not be null."); - auto x_dims = ctx.Input("X")->dims(); - auto max = Attr("max"); - auto min = Attr("min"); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of ClipOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of ClipOp should not be null."); + auto x_dims = ctx->GetInputDim("X"); + auto max = ctx->Attrs().Get("max"); + auto min = ctx->Attrs().Get("min"); PADDLE_ENFORCE_LT(min, max, "max should be greater than min."); - ctx.Output("Out")->Resize(x_dims); - ctx.ShareLoD("X", /*->*/ "Out"); + ctx->SetOutputDim("Out", x_dims); + ctx->ShareLoD("X", /*->*/ "Out"); } }; template class ClipOpMaker : public framework::OpProtoAndCheckerMaker { public: - ClipOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + ClipOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "(Tensor)The input of clip op." 
@@ -61,14 +61,13 @@ class ClipOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
-                            "Input(Out@GRAD) should not be null");
-    auto x_dims = ctx.Input<Tensor>("X")->dims();
-    auto *x_grad = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
-    if (x_grad != nullptr) {
-      x_grad->Resize(x_dims);
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@GRAD) should not be null");
+    auto x_dims = ctx->GetInputDim("X");
+    if (ctx->HasOutput(framework::GradVarName("X"))) {
+      ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
     }
   }
 };
diff --git a/paddle/operators/concat_op.cc b/paddle/operators/concat_op.cc
index 07f847079e834716904dcc038d2097efd268bd3e..01cbfc33efcb4042438fbb398fbcca9457f1334f 100644
--- a/paddle/operators/concat_op.cc
+++ b/paddle/operators/concat_op.cc
@@ -24,31 +24,30 @@ class ConcatOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
-                            "Output(Out) of ConcatOp should not be null.");
+  void InferShape(framework::InferShapeContextBase *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of ConcatOp should not be null.");
 
-    auto ins = ctx.MultiInput<framework::Tensor>("X");
-    auto *out = ctx.Output<framework::LoDTensor>("Out");
-    size_t axis = static_cast<size_t>(ctx.Attr<int>("axis"));
+    auto ins = ctx->GetInputsDim("X");
+    size_t axis = static_cast<size_t>(ctx->Attrs().Get<int>("axis"));
     size_t n = ins.size();
 
     PADDLE_ENFORCE_GT(n, 1, "Input tensors count should > 1.");
 
-    auto out_dims = ins[0]->dims();
+    auto out_dims = ins[0];
     size_t in_zero_dims_size = out_dims.size();
     for (size_t i = 1; i < n; i++) {
       for (size_t j = 0; j < in_zero_dims_size; j++) {
         if (j == axis) {
-          out_dims[axis] += ins[i]->dims()[j];
+          out_dims[axis] += ins[i][j];
           continue;
         }
-        PADDLE_ENFORCE_EQ(out_dims[j], ins[i]->dims()[j],
+        PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j],
                           "Input tensors should have the same "
                           "elements except the specify axis.")
       }
     }
-    out->Resize(out_dims);
+    ctx->SetOutputDim("Out", out_dims);
   }
 };
diff --git a/paddle/operators/cond_op.cc b/paddle/operators/cond_op.cc
index 8262a7a5c8c13c86c5f6c123a14fa89696358c57..1d44782b210bc0c40fd68dba29a16fa6959d6076 100644
--- a/paddle/operators/cond_op.cc
+++ b/paddle/operators/cond_op.cc
@@ -215,7 +215,7 @@ class CondOpProtoAndCheckerMaker : public framework::OpProtoAndCheckerMaker {
     AddComment(R"DOC(
Sample dependent Cond Operator:
Given Cond[i] as a 1/0 vector to indicate true/false
-The equation is: 
+The equation is:
Out[i] = subnet_t[i], if Cond[i] == true
Out[i] = subnet_t[i], if Cond[i] == false
)DOC");
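Note (editorial, not part of the patch): `ConcatOp` above is the first user of the multi-input helper `GetInputsDim`. The same dimension arithmetic in isolation, with invented shapes:

```cpp
// Hedged sketch: concatenating shapes along `axis`, as ConcatOp::InferShape
// now does over ctx->GetInputsDim("X"). Standalone, with invented test data.
#include <vector>

#include "paddle/framework/ddim.h"

paddle::framework::DDim ConcatDims(
    const std::vector<paddle::framework::DDim>& ins, size_t axis) {
  auto out_dims = ins[0];
  for (size_t i = 1; i < ins.size(); ++i) {
    out_dims[axis] += ins[i][axis];  // all other axes must already match
  }
  return out_dims;
}
// e.g. {2, 3} and {2, 5} concatenated on axis 1 -> {2, 8}
```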
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Input"), - "Input(Input) of Conv2DOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Filter"), - "Input(Filter) of Conv2DOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Output"), - "Output(Output) of Conv2DOp should not be null."); - - auto in = ctx.Input("Input"); - auto filter = ctx.Input("Filter"); - auto out = ctx.Output("Output"); - std::vector strides = Attr>("strides"); - std::vector paddings = Attr>("paddings"); - int groups = Attr("groups"); - int input_channels = in->dims()[1]; - int output_channels = filter->dims()[0]; - - PADDLE_ENFORCE_EQ(in->dims().size(), 4, "Conv2DOp input should be 4-D."); - PADDLE_ENFORCE_EQ(filter->dims().size(), 4, - "Conv2DOp filter should be 4-D."); - PADDLE_ENFORCE_EQ(input_channels, filter->dims()[1] * groups, + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(Input) of Conv2DOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Filter"), + "Input(Filter) of Conv2DOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Output"), + "Output(Output) of Conv2DOp should not be null."); + + auto in_dims = ctx->GetInputDim("Input"); + auto filter_dims = ctx->GetInputDim("Filter"); + std::vector strides = ctx->Attrs().Get>("strides"); + std::vector paddings = ctx->Attrs().Get>("paddings"); + int groups = ctx->Attrs().Get("groups"); + int input_channels = in_dims[1]; + int output_channels = filter_dims[0]; + + PADDLE_ENFORCE_EQ(in_dims.size(), 4, "Conv2DOp input should be 4-D."); + PADDLE_ENFORCE_EQ(filter_dims.size(), 4, "Conv2DOp filter should be 4-D."); + PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups, "The number of input channels should be equal to filter " "channels * groups."); PADDLE_ENFORCE_EQ( @@ -55,17 +53,17 @@ class Conv2DOp : public framework::OperatorWithKernel { "The number of output channels should be divided by groups."); auto output_height = - outputSize(in->dims()[2], filter->dims()[2], paddings[0], strides[0]); + outputSize(in_dims[2], filter_dims[2], paddings[0], strides[0]); auto output_width = - outputSize(in->dims()[3], filter->dims()[3], paddings[1], strides[1]); - out->Resize( - {in->dims()[0], filter->dims()[0], output_height, output_width}); + outputSize(in_dims[3], filter_dims[3], paddings[1], strides[1]); + ctx->SetOutputDim( + "Output", {in_dims[0], filter_dims[0], output_height, output_width}); } }; class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker { public: - Conv2DOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + Conv2DOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput( "Input", @@ -108,14 +106,15 @@ class Conv2DOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - auto in = ctx.Input("Input"); - auto filter = ctx.Input("Filter"); - auto d_in = ctx.Output(framework::GradVarName("Input")); - auto d_filter = - ctx.Output(framework::GradVarName("Filter")); - if (d_in) d_in->Resize(in->dims()); - if (d_filter) d_filter->Resize(filter->dims()); + void InferShape(framework::InferShapeContextBase* ctx) const override { + auto in_dims = ctx->GetInputDim("Input"); + auto filter_dims = ctx->GetInputDim("Filter"); + if (ctx->HasOutput(framework::GradVarName("Input"))) { + ctx->SetOutputDim(framework::GradVarName("Input"), 
diff --git a/paddle/operators/cos_sim_op.cc b/paddle/operators/cos_sim_op.cc
index b56ee2047b811e212b4bf74bf7fbba753a6bcb11..040546f1a6fe1af6d17a5e363a11d27de88d03c2 100644
--- a/paddle/operators/cos_sim_op.cc
+++ b/paddle/operators/cos_sim_op.cc
@@ -24,22 +24,22 @@ class CosSimOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
     // notnull check
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
-                            "Input(X) of CosSimOp should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"),
-                            "Input(Y) of CosSimOp should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
-                            "Output(Out) of CosSimOp should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("XNorm"),
-                            "Output(XNorm) of CosSimOp should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("YNorm"),
-                            "Output(YNorm) of CosSimOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of CosSimOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Y"),
+                   "Input(Y) of CosSimOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of CosSimOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("XNorm"),
+                   "Output(XNorm) of CosSimOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("YNorm"),
+                   "Output(YNorm) of CosSimOp should not be null.");
 
     // shape check
-    auto x_dims = ctx.Input<Tensor>("X")->dims();
-    auto y_dims = ctx.Input<Tensor>("Y")->dims();
+    auto x_dims = ctx->GetInputDim("X");
+    auto y_dims = ctx->GetInputDim("Y");
 
     PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(),
                       "Ranks of Input(X) and Input(Y) must be equal.");
@@ -54,16 +54,16 @@ class CosSimOp : public framework::OperatorWithKernel {
                    " just 1 (which will be broadcasted to match Input(X)).");
 
     // resize tensor
-    ctx.Output<framework::LoDTensor>("Out")->Resize({x_dims[0], 1});
-    ctx.Output<framework::LoDTensor>("XNorm")->Resize({x_dims[0], 1});
-    ctx.Output<framework::LoDTensor>("YNorm")->Resize({y_dims[0], 1});
-    ctx.ShareLoD("X", /*->*/ "Out");
+    ctx->SetOutputDim("Out", {x_dims[0], 1});
+    ctx->SetOutputDim("XNorm", {x_dims[0], 1});
+    ctx->SetOutputDim("YNorm", {y_dims[0], 1});
+    ctx->ShareLoD("X", /*->*/ "Out");
   }
 };
 
 class CosSimOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  CosSimOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+  CosSimOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The 1st input of cos_sim op.");
     AddInput("Y", "The 2nd input of cos_sim op.");
@@ -98,27 +98,23 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
     // notnull check
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) must not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("XNorm"),
-                            "Input(XNorm) must not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("YNorm"),
-                            "Input(YNorm) must not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Out"),
-                            "Input(Out) must not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
-                            "Input(Out@GRAD) must not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) must not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("XNorm"), "Input(XNorm) must not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("YNorm"), "Input(YNorm) must not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) must not be null.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@GRAD) must not be null.");
 
     // shape check
-    auto x_dims = ctx.Input<Tensor>("X")->dims();
-    auto y_dims = ctx.Input<Tensor>("Y")->dims();
-    auto xnorm_dims = ctx.Input<Tensor>("XNorm")->dims();
-    auto ynorm_dims = ctx.Input<Tensor>("YNorm")->dims();
-    auto out_dims = ctx.Input<Tensor>("Out")->dims();
-    auto out_grad_dims =
-        ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
+    auto x_dims = ctx->GetInputDim("X");
+    auto y_dims = ctx->GetInputDim("Y");
+    auto xnorm_dims = ctx->GetInputDim("XNorm");
+    auto ynorm_dims = ctx->GetInputDim("YNorm");
+    auto out_dims = ctx->GetInputDim("Out");
+    auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));
 
     PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
                       "Ranks of Input(X) and Input(Y) must be equal.");
@@ -143,10 +139,14 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
                       "Shape of Input(Out@Grad) must be [X.Dim(0), 1].");
 
     // resize tensor
-    auto *x_grad = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
-    auto *y_grad = ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));
-    if (x_grad) x_grad->Resize(x_dims);
-    if (y_grad) y_grad->Resize(y_dims);
+    auto x_grad_name = framework::GradVarName("X");
+    auto y_grad_name = framework::GradVarName("Y");
+    if (ctx->HasOutput(x_grad_name)) {
+      ctx->SetOutputDim(x_grad_name, x_dims);
+    }
+    if (ctx->HasOutput(y_grad_name)) {
+      ctx->SetOutputDim(y_grad_name, y_dims);
+    }
   }
 };
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), - "Input(Out@GRAD) must not be null."); + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) must not be null."); + PADDLE_ENFORCE(ctx->HasInput("XNorm"), "Input(XNorm) must not be null."); + PADDLE_ENFORCE(ctx->HasInput("YNorm"), "Input(YNorm) must not be null."); + PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) must not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) must not be null."); // shape check - auto x_dims = ctx.Input("X")->dims(); - auto y_dims = ctx.Input("Y")->dims(); - auto xnorm_dims = ctx.Input("XNorm")->dims(); - auto ynorm_dims = ctx.Input("YNorm")->dims(); - auto out_dims = ctx.Input("Out")->dims(); - auto out_grad_dims = - ctx.Input(framework::GradVarName("Out"))->dims(); + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + auto xnorm_dims = ctx->GetInputDim("XNorm"); + auto ynorm_dims = ctx->GetInputDim("YNorm"); + auto out_dims = ctx->GetInputDim("Out"); + auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out")); PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), "Ranks of Input(X) and Input(Y) must be equal."); @@ -143,10 +139,14 @@ class CosSimOpGrad : public framework::OperatorWithKernel { "Shape of Input(Out@Grad) must be [X.Dim(0), 1]."); // resize tensor - auto *x_grad = ctx.Output(framework::GradVarName("X")); - auto *y_grad = ctx.Output(framework::GradVarName("Y")); - if (x_grad) x_grad->Resize(x_dims); - if (y_grad) y_grad->Resize(y_dims); + auto x_grad_name = framework::GradVarName("X"); + auto y_grad_name = framework::GradVarName("Y"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + if (ctx->HasOutput(y_grad_name)) { + ctx->SetOutputDim(y_grad_name, y_dims); + } } }; diff --git a/paddle/operators/crop_op.cc b/paddle/operators/crop_op.cc index 52a1123348b10e39bcfa1ba062c893e5f20ed862..9b2305e90e85a6f39d4c584a3251b25f67e81aca 100644 --- a/paddle/operators/crop_op.cc +++ b/paddle/operators/crop_op.cc @@ -25,16 +25,14 @@ class CropOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input(X) of CropOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), - "Output(Out) of CropOp should not be null."); - auto x_dim = ctx.Input("X")->dims(); - auto *y = ctx.Input("Y"); - auto *out = ctx.Output("Out"); - if (y == nullptr) { - auto shape = Attr>("shape"); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of CropOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of CropOp should not be null."); + auto x_dim = ctx->GetInputDim("X"); + if (!ctx->HasInput("Y")) { + auto shape = ctx->Attrs().Get>("shape"); PADDLE_ENFORCE_EQ( int64_t(shape.size()), x_dim.size(), "Shape size should be equal to dimention size of input tensor."); @@ -42,19 +40,20 @@ class CropOp : public framework::OperatorWithKernel { for (size_t i = 0; i < shape.size(); ++i) { tensor_shape[i] = static_cast(shape[i]); } - out->Resize(framework::make_ddim(tensor_shape)); + ctx->SetOutputDim("Out", framework::make_ddim(tensor_shape)); } else { - PADDLE_ENFORCE_EQ(framework::arity(x_dim), framework::arity(y->dims()), + auto y_dim = ctx->GetInputDim("Y"); + 
diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc
index 2e16201e74c153888594ebe6679fb0036734dad4..26fc9b51c44d21d92851030449e116538f937846 100644
--- a/paddle/operators/cross_entropy_op.cc
+++ b/paddle/operators/cross_entropy_op.cc
@@ -22,33 +22,30 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should be not null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
-                            "Input(Label) should be not null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"),
-                            "Output(Y) should be not null.");
-
-    auto x = ctx.Input<Tensor>("X");
-    auto label = ctx.Input<Tensor>("Label");
-    PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(label->dims().size(), 2,
-                      "Input(Label)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
+    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null.");
+
+    auto x_dims = ctx->GetInputDim("X");
+    auto label_dims = ctx->GetInputDim("Label");
+    PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
+    PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2.");
+    PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
                       "The 1st dimension of Input(X) and Input(Label) should "
                       "be equal.");
-    if (ctx.Attr<bool>("softLabel")) {
-      PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
+    if (ctx->Attrs().Get<bool>("softLabel")) {
+      PADDLE_ENFORCE_EQ(x_dims[1], label_dims[1],
                         "If Attr(softLabel) == true, the 2nd dimension of "
                         "Input(X) and Input(Label) should be equal.");
     } else {
-      PADDLE_ENFORCE_EQ(label->dims()[1], 1,
+      PADDLE_ENFORCE_EQ(label_dims[1], 1,
                         "If Attr(softLabel) == false, the 2nd dimension of "
                         "Input(Label) should be 1.");
     }
 
-    ctx.Output<framework::LoDTensor>("Y")->Resize({x->dims()[0], 1});
-    ctx.ShareLoD("X", /*->*/ "Y");
+    ctx->SetOutputDim("Y", {x_dims[0], 1});
+    ctx->ShareLoD("X", /*->*/ "Y");
   }
 };
PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null."); + + auto x_dims = ctx->GetInputDim("X"); + auto label_dims = ctx->GetInputDim("Label"); + PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); + PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2."); + PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0], "The 1st dimension of Input(X) and Input(Label) should " "be equal."); - if (ctx.Attr("softLabel")) { - PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1], + if (ctx->Attrs().Get("softLabel")) { + PADDLE_ENFORCE_EQ(x_dims[1], label_dims[1], "If Attr(softLabel) == true, the 2nd dimension of " "Input(X) and Input(Label) should be equal."); } else { - PADDLE_ENFORCE_EQ(label->dims()[1], 1, + PADDLE_ENFORCE_EQ(label_dims[1], 1, "If Attr(softLabel) == false, the 2nd dimension of " "Input(Label) should be 1."); } - ctx.Output("Y")->Resize({x->dims()[0], 1}); - ctx.ShareLoD("X", /*->*/ "Y"); + ctx->SetOutputDim("Y", {x_dims[0], 1}); + ctx->ShareLoD("X", /*->*/ "Y"); } }; @@ -57,50 +54,45 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should be not null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), - "Input(Label) should be not null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")), - "Input(Y@GRAD) shoudl be not null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(framework::GradVarName("X")), - "Output(X@GRAD) should be not null."); - - auto x = ctx.Input("X"); - auto label = ctx.Input("Label"); - auto dy = ctx.Input(framework::GradVarName("Y")); - PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank should be 2."); - PADDLE_ENFORCE_EQ(dy->dims().size(), 2, - "Input(Y@Grad)'s rank should be 2."); - PADDLE_ENFORCE_EQ(label->dims().size(), 2, - "Input(Label)'s rank should be 2."); - PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0], + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), + "Input(Y@GRAD) shoudl be not null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Output(X@GRAD) should be not null."); + + auto x_dims = ctx->GetInputDim("X"); + auto label_dims = ctx->GetInputDim("Label"); + auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y")); + PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); + PADDLE_ENFORCE_EQ(dy_dims.size(), 2, "Input(Y@Grad)'s rank should be 2."); + PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2."); + PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0], "The 1st dimension of Input(X) and Input(Label) should " "be equal."); - PADDLE_ENFORCE_EQ(x->dims()[0], dy->dims()[0], + PADDLE_ENFORCE_EQ(x_dims[0], dy_dims[0], "The 1st dimension of Input(X) and Input(Y@Grad) should " "be equal."); - PADDLE_ENFORCE_EQ(dy->dims()[1], 1, + PADDLE_ENFORCE_EQ(dy_dims[1], 1, "The 2nd dimension of Input(Y@Grad) should be 1."); - if (ctx.Attr("softLabel")) { - PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1], + if (ctx->Attrs().Get("softLabel")) { + PADDLE_ENFORCE_EQ(x_dims[1], label_dims[1], "When Attr(softLabel) == true, the 2nd dimension of " "Input(X) and Input(Label) should be equal."); } else { - 
diff --git a/paddle/operators/dropout_op.cc b/paddle/operators/dropout_op.cc
index 2130eda6a42c893d8ec251a7022a0bfa44433bb7..a669b5cf00f1f4ad351486e2977bf8a76aa5bf62 100644
--- a/paddle/operators/dropout_op.cc
+++ b/paddle/operators/dropout_op.cc
@@ -24,25 +24,25 @@ class DropoutOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
-    PADDLE_ENFORCE_GE(ctx.Attr<float>("dropout_prob"), 0);
-    PADDLE_ENFORCE_LE(ctx.Attr<float>("dropout_prob"), 1);
-
-    auto dims = ctx.Input<Tensor>("X")->dims();
-    ctx.Output<framework::LoDTensor>("Out")->Resize(dims);
-    if (ctx.Attr<int>("is_training")) {
-      ctx.Output<framework::LoDTensor>("Mask")->Resize(dims);
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
+    PADDLE_ENFORCE_GE(ctx->Attrs().Get<float>("dropout_prob"), 0);
+    PADDLE_ENFORCE_LE(ctx->Attrs().Get<float>("dropout_prob"), 1);
+
+    auto x_dims = ctx->GetInputDim("X");
+    ctx->SetOutputDim("Out", x_dims);
+    if (ctx->Attrs().Get<int>("is_training") == 1) {
+      ctx->SetOutputDim("Mask", x_dims);
     }
-    ctx.ShareLoD("X", /*->*/ "Out");
+    ctx->ShareLoD("X", /*->*/ "Out");
   }
 };
 
 template <typename AttrType>
 class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  DropoutOpMaker(framework::OpProto *proto,
-                 framework::OpAttrChecker *op_checker)
+  DropoutOpMaker(framework::OpProto* proto,
+                 framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddAttr<AttrType>("dropout_prob", "Probability of setting units to zero.")
         .SetDefault(.5f);
@@ -70,27 +70,26 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE(ctx.Attr<int>("is_training"),
-                   "GradOp is only callable when is_training is true");
-
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Mask"), "Mask must not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
-                            "Input(Out@GRAD) must not be null.");
-
-    PADDLE_ENFORCE_GE(ctx.Attr<float>("dropout_prob"), 0);
-    PADDLE_ENFORCE_LE(ctx.Attr<float>("dropout_prob"), 1);
-    auto x_dims = ctx.Input<Tensor>("X")->dims();
-    auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE_EQ(ctx->Attrs().Get<int>("is_training"), 1,
+                      "GradOp is only callable when is_training is true");
+
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Mask"), "Mask must not be null.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@GRAD) must not be null.");
+
+    PADDLE_ENFORCE_GE(ctx->Attrs().Get<float>("dropout_prob"), 0);
+    PADDLE_ENFORCE_LE(ctx->Attrs().Get<float>("dropout_prob"), 1);
+    auto x_dims = ctx->GetInputDim("X");
+    auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
     PADDLE_ENFORCE_EQ(x_dims, out_dims,
                       "Dimensions of Input(X) and Out@Grad must be the same.");
-    auto mask_dims = ctx.Input<Tensor>("Mask")->dims();
+    auto mask_dims = ctx->GetInputDim("Mask");
     PADDLE_ENFORCE_EQ(x_dims, mask_dims,
                       "Dimensions of Input(X) and Mask must be the same.");
 
-    auto *x_grad = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
-    x_grad->Resize(x_dims);
+    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
   }
 };
not be null."); + + PADDLE_ENFORCE_GE(ctx->Attrs().Get("dropout_prob"), 0); + PADDLE_ENFORCE_LE(ctx->Attrs().Get("dropout_prob"), 1); + auto x_dims = ctx->GetInputDim("X"); + auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); PADDLE_ENFORCE_EQ(x_dims, out_dims, "Dimensions of Input(X) and Out@Grad must be the same."); - auto mask_dims = ctx.Input("Mask")->dims(); + auto mask_dims = ctx->GetInputDim("Mask"); PADDLE_ENFORCE_EQ(x_dims, mask_dims, "Dimensions of Input(X) and Mask must be the same."); - auto *x_grad = ctx.Output(framework::GradVarName("X")); - x_grad->Resize(x_dims); + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); } }; diff --git a/paddle/operators/elementwise_op.h b/paddle/operators/elementwise_op.h index f224722c1bec6716e68de9da2509250f7d4b37ae..c4777a00d6781ea751123b2efffb6df8e29630b0 100644 --- a/paddle/operators/elementwise_op.h +++ b/paddle/operators/elementwise_op.h @@ -202,21 +202,20 @@ class ElementwiseOp : public framework::OperatorWithKernel { protected: using Tensor = framework::Tensor; - void InferShape(const framework::InferShapeContext& ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input(X) of elementwise op should not be null"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), - "Input(Y) of elementwise op should not be null"); - PADDLE_ENFORCE_NOT_NULL( - ctx.OutputVar("Out"), - "Output(Out) of elementwise op should not be null."); - - auto x_dim = ctx.Input("X")->dims(); - auto y_dim = ctx.Input("Y")->dims(); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of elementwise op should not be null"); + PADDLE_ENFORCE(ctx->HasInput("Y"), + "Input(Y) of elementwise op should not be null"); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of elementwise op should not be null."); + + auto x_dim = ctx->GetInputDim("X"); + auto y_dim = ctx->GetInputDim("Y"); PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(), "Rank of first input must >= rank of second input.") - ctx.Output("Out")->Resize(x_dim); - ctx.ShareLoD("X", /*->*/ "Out"); + ctx->SetOutputDim("Out", x_dim); + ctx->ShareLoD("X", /*->*/ "Out"); } }; @@ -234,7 +233,7 @@ must be small or equal to X's dimensions. )DOC"); AddAttr("axis", R"DOC( -When the shape(Y) does not equal the shape(X),Y will be broadcasted +When the shape(Y) does not equal the shape(X),Y will be broadcasted to match the shape of X and axis should be dimension index Y in X )DOC") .SetDefault(-1) @@ -244,7 +243,7 @@ to match the shape of X and axis should be dimension index Y in X comment_ = R"DOC( Limited elementwise {name} operator.The equation is: Out = {equation}. 1. The shape of Y should be same with X or -2. Y's shape is a subset of X. +2. Y's shape is a subset of X. Y will be broadcasted to match the shape of X and axis should be dimension index Y in X. 
diff --git a/paddle/operators/fill_zeros_like_op.cc b/paddle/operators/fill_zeros_like_op.cc
index 761a527a5574edc779340ec595dfe1bc1964438a..e164de6584e7350283781019cc74118c2d13646e 100644
--- a/paddle/operators/fill_zeros_like_op.cc
+++ b/paddle/operators/fill_zeros_like_op.cc
@@ -22,15 +22,13 @@ class FillZerosLikeOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
-                            "Input(X) of FillZerosLikeOp should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"),
-                            "Output(Y) of FillZerosLikeOp should not be null.");
-
-    ctx.Output<framework::LoDTensor>("Y")->Resize(
-        ctx.Input<framework::Tensor>("X")->dims());
-    ctx.ShareLoD("X", /*->*/ "Y");
+  void InferShape(framework::InferShapeContextBase *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of FillZerosLikeOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Y"),
+                   "Output(Y) of FillZerosLikeOp should not be null.");
+    ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
+    ctx->ShareLoD("X", /*->*/ "Y");
   }
 };
diff --git a/paddle/operators/gather_op.cc b/paddle/operators/gather_op.cc
index fecd1ce2147a1e6f2f7928266be74ed7b647c5b9..0e3cd174adee1e50d0a63861286a26d325484efb 100644
--- a/paddle/operators/gather_op.cc
+++ b/paddle/operators/gather_op.cc
@@ -23,19 +23,19 @@ class GatherOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
-                            "Input(X) of GatherOp should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Index"),
-                            "Input(Index) of GatherOp should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
-                            "Output(Out) of GatherOp should not be null.");
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of GatherOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Index"),
+                   "Input(Index) of GatherOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of GatherOp should not be null.");
 
-    int batch_size = ctx.Input<Tensor>("Index")->dims()[0];
+    int batch_size = ctx->GetInputDim("Index")[0];
     PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0");
-    framework::DDim output_dims(ctx.Input<Tensor>("X")->dims());
+    framework::DDim output_dims(ctx->GetInputDim("X"));
     output_dims[0] = batch_size;
-    ctx.Output<framework::LoDTensor>("Out")->Resize(output_dims);
+    ctx->SetOutputDim("Out", output_dims);
   }
 };
 
@@ -44,23 +44,20 @@ class GatherGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    auto X_grad = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
-    auto X = ctx.Input<Tensor>("X");
-
-    X_grad->Resize(X->dims());
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
   }
 };
 
 class GatherOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  GatherOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+  GatherOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The source input of gather op");
     AddInput("Index", "The index input of gather op");
     AddOutput("Out", "The output of add op");
     AddComment(R"DOC(
-Gather Operator by selecting from the first axis, 
+Gather Operator by selecting from the first axis,
 
Out = X[Index]
)DOC");
null."); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of GatherOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Index"), + "Input(Index) of GatherOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of GatherOp should not be null."); - int batch_size = ctx.Input("Index")->dims()[0]; + int batch_size = ctx->GetInputDim("Index")[0]; PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0"); - framework::DDim output_dims(ctx.Input("X")->dims()); + framework::DDim output_dims(ctx->GetInputDim("X")); output_dims[0] = batch_size; - ctx.Output("Out")->Resize(output_dims); + ctx->SetOutputDim("Out", output_dims); } }; @@ -44,23 +44,20 @@ class GatherGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - auto X_grad = ctx.Output(framework::GradVarName("X")); - auto X = ctx.Input("X"); - - X_grad->Resize(X->dims()); + void InferShape(framework::InferShapeContextBase* ctx) const override { + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); } }; class GatherOpMaker : public framework::OpProtoAndCheckerMaker { public: - GatherOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + GatherOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The source input of gather op"); AddInput("Index", "The index input of gather op"); AddOutput("Out", "The output of add op"); AddComment(R"DOC( -Gather Operator by selecting from the first axis, +Gather Operator by selecting from the first axis, Out = X[Index] )DOC"); diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc index 5b7cbb5cc7bcb7e43b15363d37d7b8f2cbf0fbdc..05120a6e7bcfdb8641c722731f462c89e4223339 100644 --- a/paddle/operators/gaussian_random_op.cc +++ b/paddle/operators/gaussian_random_op.cc @@ -43,13 +43,10 @@ class GaussianRandomOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext& ctx) const override { - PADDLE_ENFORCE_NOT_NULL( - ctx.OutputVar("Out"), - "Output(Out) of GaussianRandomOp should not be null."); - - auto* tensor = ctx.Output("Out"); - auto dims = Attr>("dims"); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of GaussianRandomOp should not be null."); + auto dims = ctx->Attrs().Get>("dims"); std::vector temp; temp.reserve(dims.size()); for (auto dim : dims) { @@ -57,7 +54,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel { } PADDLE_ENFORCE(dims.size() > 0UL, "dims can be one int or array. 
dims must be set."); - tensor->Resize(framework::make_ddim(temp)); + ctx->SetOutputDim("Out", framework::make_ddim(temp)); } }; diff --git a/paddle/operators/lookup_table_op.cc b/paddle/operators/lookup_table_op.cc index 04ac24662e9cfec6a49cd213cb76bdebc7b730c8..9b1314bfbade8551d98b0fbabb7c2968d7600db5 100644 --- a/paddle/operators/lookup_table_op.cc +++ b/paddle/operators/lookup_table_op.cc @@ -22,27 +22,26 @@ class LookupTableOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("W"), - "Input(W) of LookupTableOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Ids"), - "Input(Ids) of LookupTableOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), - "Output(Out) of LookupTableOp should not be null."); - - auto table_t = ctx.Input("W"); - auto ids_t = ctx.Input("Ids"); - auto output_t = ctx.Output("Out"); - - output_t->Resize({ids_t->dims()[0], table_t->dims()[1]}); - ctx.ShareLoD("Ids", /*->*/ "Out"); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("W"), + "Input(W) of LookupTableOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Ids"), + "Input(Ids) of LookupTableOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of LookupTableOp should not be null."); + + auto table_dims = ctx->GetInputDim("W"); + auto ids_dims = ctx->GetInputDim("Ids"); + + ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]}); + ctx->ShareLoD("Ids", /*->*/ "Out"); } }; class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { public: - LookupTableOpMaker(framework::OpProto *proto, - framework::OpAttrChecker *op_checker) + LookupTableOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("W", "An input represents embedding tensors," @@ -66,11 +65,9 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &context) const override { - auto table = context.Input("W"); - auto d_table = - context.Output(framework::GradVarName("W")); - d_table->Resize(table->dims()); + void InferShape(framework::InferShapeContextBase* ctx) const override { + auto table_dims = ctx->GetInputDim("W"); + ctx->SetOutputDim(framework::GradVarName("W"), table_dims); } }; diff --git a/paddle/operators/lstm_unit_op.cc b/paddle/operators/lstm_unit_op.cc index 3600f199770c4b8c9a6561b4c270a91bc8b20c0b..bd75b001cb87d914f6c56ea35dcb5013d68145b2 100644 --- a/paddle/operators/lstm_unit_op.cc +++ b/paddle/operators/lstm_unit_op.cc @@ -22,37 +22,36 @@ class LstmUnitOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input(X) of LSTM should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("C_prev"), - "Input(C_prev) of LSTM should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("C"), - "Output(C) of LSTM should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("H"), - "Output(H) of LSTM should not be null."); - - auto *x = ctx.Input("X"); - auto *c_prev = ctx.Input("C_prev"); - - PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 
2."); - PADDLE_ENFORCE(x->dims()[0] == c_prev->dims()[0], + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("C_prev"), + "Input(C_prev) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("C"), + "Output(C) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("H"), + "Output(H) of LSTM should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + auto c_prev_dims = ctx->GetInputDim("C_prev"); + + PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); + PADDLE_ENFORCE(x_dims[0] == c_prev_dims[0], "Batch size of inputs and states must be equal"); - PADDLE_ENFORCE(x->dims()[1] == c_prev->dims()[1] * 4, + PADDLE_ENFORCE(x_dims[1] == c_prev_dims[1] * 4, "Dimension of FC should equal to prev state * 4"); - int b_size = c_prev->dims()[0]; // batch size - int s_dim = c_prev->dims()[1]; // state dim - ctx.Output("C")->Resize({b_size, s_dim}); - ctx.Output("H")->Resize({b_size, s_dim}); + int b_size = c_prev_dims[0]; // batch size + int s_dim = c_prev_dims[1]; // state dim + ctx->SetOutputDim("C", {b_size, s_dim}); + ctx->SetOutputDim("H", {b_size, s_dim}); } }; template class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker { public: - LstmUnitOpMaker(framework::OpProto *proto, - framework::OpAttrChecker *op_checker) + LstmUnitOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "FC input before the non-linear activation."); AddInput( @@ -63,11 +62,11 @@ class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC(Lstm-Unit Operator -Equation: +Equation: i, f, o, j = split(X) C = C_prev * sigm(f + forget_bias) + sigm(i) * tanh(j) H = C * sigm(o) - + )DOC"); AddAttr("forget_bias", "The forget bias of Lstm Unit.") .SetDefault(0.0); @@ -79,15 +78,14 @@ class LstmUnitGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("C")), - "Input(C@GRAD) should not be null"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("H")), - "Input(H@GRAD) should not be null"); - ctx.Output(framework::GradVarName("X")) - ->Resize(ctx.Input("X")->dims()); - ctx.Output(framework::GradVarName("C_prev")) - ->Resize(ctx.Input("C_prev")->dims()); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("C")), + "Input(C@GRAD) should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("H")), + "Input(H@GRAD) should not be null"); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + ctx->SetOutputDim(framework::GradVarName("C_prev"), + ctx->GetInputDim("C_prev")); } }; diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc index b04384bda81b93f5db0be3206eee10ad5e854540..d799239d4ed6d230578c77921a1a454b476b63fa 100644 --- a/paddle/operators/mean_op.cc +++ b/paddle/operators/mean_op.cc @@ -22,18 +22,18 @@ class MeanOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input(X) of MeanOp should not be null."); - 
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), - "Output(Out) of MeanOp should not be null."); - ctx.Output("Out")->Resize({1}); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of MeanOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of MeanOp should not be null."); + ctx->SetOutputDim("Out", {1}); } }; class MeanOpMaker : public framework::OpProtoAndCheckerMaker { public: - MeanOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + MeanOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The input of mean op"); AddOutput("Out", "The output of mean op").NotInGradient(); @@ -47,9 +47,8 @@ class MeanGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - ctx.Output(framework::GradVarName("X")) - ->Resize(ctx.Input("X")->dims()); + void InferShape(framework::InferShapeContextBase* ctx) const override { + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); } }; diff --git a/paddle/operators/minus_op.cc b/paddle/operators/minus_op.cc index 29cb85489bd05f6c1e7143d962eac0af26e75825..ce049d4d7bd96a6758d71b381e6e6b4edbcc8b5c 100644 --- a/paddle/operators/minus_op.cc +++ b/paddle/operators/minus_op.cc @@ -26,22 +26,22 @@ class MinusOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input(X) of MinusOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), - "Input(Y) of MinusOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), - "Output(Out) of MinusOp should not be null."); + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of MinusOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), + "Input(Y) of MinusOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of MinusOp should not be null."); - auto *left_tensor = ctx.Input("X"); - auto *right_tensor = ctx.Input("Y"); + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); PADDLE_ENFORCE_EQ( - left_tensor->numel(), right_tensor->numel(), + x_dims, y_dims, "Minus operator must take two tensor with same num of elements"); - ctx.Output("Out")->Resize(left_tensor->dims()); - ctx.ShareLoD("X", /*->*/ "Out"); + ctx->SetOutputDim("Out", x_dims); + ctx->ShareLoD("X", /*->*/ "Out"); } }; diff --git a/paddle/operators/modified_huber_loss_op.cc b/paddle/operators/modified_huber_loss_op.cc index 8606c0d1e1bf7a52299528d30af0367d9f93edd2..84212a2b3be1ac3664ebd77c7a0ae4d86abad3a0 100644 --- a/paddle/operators/modified_huber_loss_op.cc +++ b/paddle/operators/modified_huber_loss_op.cc @@ -22,20 +22,19 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext& context) const override { - PADDLE_ENFORCE_NOT_NULL(context.InputVar("X"), "X must be initialized."); - PADDLE_ENFORCE_NOT_NULL(context.InputVar("Y"), "Y must be initialized."); + void InferShape(framework::InferShapeContextBase* ctx) const override { + 
PADDLE_ENFORCE(ctx->HasInput("X"), "X must be initialized."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Y must be initialized."); - auto* x = context.Input("X"); - auto* y = context.Input("Y"); + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); - PADDLE_ENFORCE_EQ(x->dims(), y->dims(), - "The shape of X and Y must be the same."); - PADDLE_ENFORCE_EQ(x->dims().size(), 2, "The tensor rank of X must be 2."); - PADDLE_ENFORCE_EQ(x->dims()[1], 1, "The 2nd dimension of X must be 1."); + PADDLE_ENFORCE_EQ(x_dims, y_dims, "The shape of X and Y must be the same."); + PADDLE_ENFORCE_EQ(x_dims.size(), 2, "The tensor rank of X must be 2."); + PADDLE_ENFORCE_EQ(x_dims[1], 1, "The 2nd dimension of X must be 1."); - context.Output("IntermediateVal")->Resize(x->dims()); - context.Output("Out")->Resize({x->dims()[0], 1}); + ctx->SetOutputDim("IntermediateVal", x_dims); + ctx->SetOutputDim("Out", {x_dims[0], 1}); } }; @@ -75,27 +74,28 @@ class ModifiedHuberLossGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext& context) const override { - auto* x = context.Input("X"); - auto* y = context.Input("Y"); - auto* intermediate_val = context.Input("IntermediateVal"); - auto* out_grad = context.Input(framework::GradVarName("Out")); - auto* x_grad = - context.Output(framework::GradVarName("X")); - - PADDLE_ENFORCE_NOT_NULL(x, "X must be initialized."); - PADDLE_ENFORCE_NOT_NULL(y, "Y must be initialized."); - PADDLE_ENFORCE_NOT_NULL(intermediate_val, - "Intermediate value must not be null."); - PADDLE_ENFORCE_NOT_NULL(out_grad, "Input(Out@Grad) must not be null."); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "X must be initialized."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Y must be initialized."); + PADDLE_ENFORCE(ctx->HasInput("IntermediateVal"), + "Intermediate value must not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@Grad) must not be null."); + + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + auto intermediate_dims = ctx->GetInputDim("IntermediateVal"); + auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out")); PADDLE_ENFORCE_EQ( - intermediate_val->dims(), x->dims(), + intermediate_dims, x_dims, "The shape of X and intermediate value must be the same."); - PADDLE_ENFORCE_EQ(out_grad->dims(), x->dims(), + PADDLE_ENFORCE_EQ(out_grad_dims, x_dims, "The shape of Input(Out@Grad) and X must be the same."); - if (x_grad) x_grad->Resize(x->dims()); + if (ctx->HasOutput(framework::GradVarName("X"))) { + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + } } }; diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 7047718a3f1bf7e9598952efa1d9bcb20d5cf5b4..9858c4d9c2195c7bd0e767aaa86a950e0a791443 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -24,27 +24,23 @@ class MulOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input(X) of MulOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), - "Input(Y) of MulOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), - "Output(Out) of MulOp should not be null."); - - auto x_dims = ctx.Input("X")->dims(); - auto y_dims 
= ctx.Input("Y")->dims(); - int x_num_col_dims = Attr("x_num_col_dims"); - int y_num_col_dims = Attr("y_num_col_dims"); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of MulOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) of MulOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of MulOp should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + int x_num_col_dims = ctx->Attrs().Get("x_num_col_dims"); + int y_num_col_dims = ctx->Attrs().Get("y_num_col_dims"); PADDLE_ENFORCE(x_dims.size() > x_num_col_dims, - "The rank of input tensor X(%s) should be larger than " - "`mul_op`'s `x_num_col_dims`.", - ctx.op().Input("X")); + "The rank of input tensor X should be larger than " + "`mul_op`'s `x_num_col_dims`."); PADDLE_ENFORCE(y_dims.size() > y_num_col_dims, - "The rank of input tensor Y(%s) should be larger than " - "`mul_op`'s `y_num_col_dims`.", - ctx.op().Input("Y")); + "The rank of input tensor Y should be larger than " + "`mul_op`'s `y_num_col_dims`."); auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims); auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_col_dims); @@ -52,24 +48,23 @@ class MulOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ( x_mat_dims[1], y_mat_dims[0], "First matrix's width must be equal with second matrix's height."); - ctx.Output("Out")->Resize( - {x_mat_dims[0], y_mat_dims[1]}); - ctx.ShareLoD("X", /*->*/ "Out"); + ctx->SetOutputDim("Out", {x_mat_dims[0], y_mat_dims[1]}); + ctx->ShareLoD("X", /*->*/ "Out"); } }; class MulOpMaker : public framework::OpProtoAndCheckerMaker { public: - MulOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + MulOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The first input of mul op"); AddInput("Y", "The second input of mul op"); AddOutput("Out", "The output of mul op"); AddAttr( "x_num_col_dims", - R"DOC(mul_op can take tensors with more than two dimensions as input `X`, - in that case, tensors will be reshaped to a matrix. The matrix's first - dimension(column length) will be the product of tensor's last + R"DOC(mul_op can take tensors with more than two dimensions as input `X`, + in that case, tensors will be reshaped to a matrix. The matrix's first + dimension(column length) will be the product of tensor's last `num_col_dims` dimensions, and the matrix's second dimension(row length) will be the product of tensor's first `rank - num_col_dims` dimensions. 
)DOC") @@ -100,16 +95,14 @@ class MulOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), - "Input(Out@GRAD) should not be null"); - auto x_dims = ctx.Input("X")->dims(); - auto y_dims = ctx.Input("Y")->dims(); - auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims(); - auto *x_grad = ctx.Output(framework::GradVarName("X")); - auto *y_grad = ctx.Output(framework::GradVarName("Y")); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); auto x_mat_dims = framework::flatten_to_2d(x_dims, Attr("x_num_col_dims")); @@ -125,8 +118,15 @@ class MulOpGrad : public framework::OperatorWithKernel { "The second dimension of Out@GRAD must equal to the second " "dimension of the second operand."); - if (x_grad) x_grad->Resize(x_dims); - if (y_grad) y_grad->Resize(y_dims); + auto x_grad_name = framework::GradVarName("X"); + auto y_grad_name = framework::GradVarName("Y"); + + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + if (ctx->HasOutput(y_grad_name)) { + ctx->SetOutputDim(y_grad_name, y_dims); + } } }; diff --git a/paddle/operators/multiplex_op.cc b/paddle/operators/multiplex_op.cc index 7b50444d16dc57fd14b918d1159e3e21ecd1f1c4..9896d269ccc86d8fdc3bf6375e44ef5bf3e6b9c7 100644 --- a/paddle/operators/multiplex_op.cc +++ b/paddle/operators/multiplex_op.cc @@ -24,41 +24,38 @@ class MultiplexOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Ids"), - "Input(Ids) shouldn't be null."); - PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(), + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Ids"), "Input(Ids) shouldn't be null."); + PADDLE_ENFORCE(!ctx->Inputs("X").empty(), "MultiInput(X) shouldn't be empty."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), - "Output(Out) shouldn't be null."); - auto ids_dim = ctx.Input("Ids")->dims(); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) shouldn't be null."); + auto ids_dim = ctx->GetInputDim("Ids"); PADDLE_ENFORCE( ids_dim.size() == 2 && ids_dim[1] == 1, "The index tensor must be a vector with size batchSize x 1."); - auto ins = ctx.MultiInput("X"); - auto *out = ctx.Output("Out"); - auto num_ins = ins.size(); + auto ins_dims = ctx->GetInputsDim("X"); + auto num_ins = ins_dims.size(); PADDLE_ENFORCE(num_ins > 1, "multiplex operator should have more than " "one candidate input tensors."); - auto in_dim = ins[0]->dims(); + auto in_dim = ins_dims[0]; PADDLE_ENFORCE(in_dim.size() >= 2, "The rank of candidate tensors must be not less than 2."); for (size_t i = 1; i < num_ins; i++) { - auto dim = ins[i]->dims(); + auto dim = ins_dims[i]; 
PADDLE_ENFORCE(in_dim == dim, "All the candidate tensors must have the same size."); } - out->Resize(in_dim); + ctx->SetOutputDim("Out", in_dim); } }; class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker { public: - MultiplexOpMaker(framework::OpProto *proto, - framework::OpAttrChecker *op_checker) + MultiplexOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("Ids", "The index tensor of multiplex operator."); AddInput("X", "The candidate tensors of multiplex operator.") @@ -88,21 +85,19 @@ class MultiplexGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(), - "Input(X) should not be null."); - PADDLE_ENFORCE(!ctx.MultiOutputVar(framework::GradVarName("X")).empty(), + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(!ctx->Inputs("X").empty(), "Input(X) should not be null."); + PADDLE_ENFORCE(!ctx->Outputs(framework::GradVarName("X")).empty(), "Output(X@Grad) should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), - "Input(Out@GRAD) should not be null."); - auto d_ins = ctx.MultiOutput(framework::GradVarName("X")); - auto ins = ctx.MultiInput("X"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null."); + std::vector d_ins; + auto ins = ctx->GetInputsDim("X"); // No need to compute gradient for Input(Ids) for (size_t i = 0; i < ins.size(); i++) { - if (d_ins[i]) { - d_ins[i]->Resize(ins[i]->dims()); - } + d_ins.push_back(ins[i]); } + ctx->SetOutputsDim(framework::GradVarName("X"), d_ins); } }; diff --git a/paddle/operators/pad_op.cc b/paddle/operators/pad_op.cc index 375d8a35acc0716259071c31bc332fdf5aabce1c..04ebb14f6ee6c73f48aa2f75811a22f9b8a25006 100644 --- a/paddle/operators/pad_op.cc +++ b/paddle/operators/pad_op.cc @@ -24,14 +24,13 @@ class PadOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input(X) of PadOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), - "Output(Out) of PadOp should not be null."); - - auto x_dim = ctx.Input("X")->dims(); - auto paddings = Attr>("paddings"); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of PadOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of PadOp should not be null."); + + auto x_dim = ctx->GetInputDim("X"); + auto paddings = ctx->Attrs().Get>("paddings"); PADDLE_ENFORCE_EQ(x_dim.size() * 2, int64_t(paddings.size()), "Size of paddings should be equal to 2 * dimension size " "of input tensor."); @@ -39,19 +38,18 @@ class PadOp : public framework::OperatorWithKernel { for (int i = 0; i < x_dim.size(); ++i) { out_dims[i] = x_dim[i] + paddings[i * 2] + paddings[i * 2 + 1]; } - ctx.Output("Out")->Resize( - framework::make_ddim(out_dims)); + ctx->SetOutputDim("Out", framework::make_ddim(out_dims)); if (out_dims[0] == x_dim[0]) { // Only pass LoD when the first dimension is equal between // output and input. 
- ctx.ShareLoD("X", /*->*/ "Out"); + ctx->ShareLoD("X", /*->*/ "Out"); } } }; class PadOpMaker : public framework::OpProtoAndCheckerMaker { public: - PadOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + PadOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The input of pad op. " @@ -68,15 +66,15 @@ Given: X = [[1, 2], [3, 4]] -and +and paddings = [0, 1, 1, 2] and - -pad_value = 0 -then we get +pad_value = 0 + +then we get Out = [[0, 1, 2, 0, 0] [0, 3, 4, 0, 0] @@ -101,14 +99,14 @@ class PadOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), - "Input(Out@GRAD) should not be null"); - auto x_dims = ctx.Input("X")->dims(); - auto *x_g = ctx.Output(framework::GradVarName("X")); - if (x_g != nullptr) { - x_g->Resize(x_dims); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto x_dims = ctx->GetInputDim("X"); + auto x_grad_name = framework::GradVarName("X"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); } } }; diff --git a/paddle/operators/prelu_op.cc b/paddle/operators/prelu_op.cc index 912196c190b5ddbd4e3482a5314e949186b94368..1692464f2833a59243ccc1598422180262a59282 100644 --- a/paddle/operators/prelu_op.cc +++ b/paddle/operators/prelu_op.cc @@ -26,19 +26,14 @@ class PReluOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); - auto *in = ctx.Input("X"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Alpha"), - "Input(Alpha) should not be null"); - auto *alpha = ctx.Input("Alpha"); - PADDLE_ENFORCE(alpha->numel() == 1, "Size of weight Alpha must be one."); - - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), - "Output(Out) should not be null"); - auto *out = ctx.Output("Out"); - out->Resize(in->dims()); - ctx.ShareLoD("X", /*->*/ "Out"); + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + PADDLE_ENFORCE(ctx->HasInput("Alpha"), "Input(Alpha) should not be null"); + PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == 1, + "Size of weight Alpha must be one."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null"); + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + ctx->ShareLoD("X", /*->*/ "Out"); } }; @@ -68,19 +63,13 @@ class PReluGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), - "Input(Out@GRAD) should not be null"); - auto *dx = ctx.Output(framework::GradVarName("X")); - auto *x = ctx.Input("X"); - - auto *dalpha = - ctx.Output(framework::GradVarName("Alpha")); - auto *alpha = 
ctx.Input("Alpha"); - - dx->Resize(x->dims()); - dalpha->Resize(alpha->dims()); + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + ctx->SetOutputDim(framework::GradVarName("Alpha"), + ctx->GetInputDim("Alpha")); } }; diff --git a/paddle/operators/rank_loss_op.cc b/paddle/operators/rank_loss_op.cc index 39af08c8751c3b95cf5fdef7395186a0176a20a2..1ba22006f27abc963e7f161636a964863513a40c 100644 --- a/paddle/operators/rank_loss_op.cc +++ b/paddle/operators/rank_loss_op.cc @@ -25,22 +25,21 @@ class RankLossOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - void InferShape(const framework::InferShapeContext &ctx) const override { + void InferShape(framework::InferShapeContextBase *ctx) const override { // input check - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), - "Input(Label) shouldn't be null"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Left"), - "Input(Left) shouldn't be null"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Right"), - "Input(Right) shouldn't be null"); - auto label_dims = ctx.Input("Label")->dims(); - auto left_dims = ctx.Input("Left")->dims(); - auto right_dims = ctx.Input("Right")->dims(); + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) shouldn't be null"); + PADDLE_ENFORCE(ctx->HasInput("Left"), "Input(Left) shouldn't be null"); + PADDLE_ENFORCE(ctx->HasInput("Right"), "Input(Right) shouldn't be null"); + + auto label_dims = ctx->GetInputDim("Label"); + auto left_dims = ctx->GetInputDim("Left"); + auto right_dims = ctx->GetInputDim("Right"); + PADDLE_ENFORCE((label_dims == left_dims) && (left_dims == right_dims), "All inputs must have the same size"); PADDLE_ENFORCE((label_dims.size() == 2) && (label_dims[1] == 1), "All inputs must be row vector with size batch_size x 1."); - ctx.Output("Out")->Resize(label_dims); + ctx->SetOutputDim("Out", label_dims); } }; @@ -91,25 +90,22 @@ class RankLossGradOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), - "Input(Label) shouldn't be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Left"), - "Input(Left) shouldn't be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Right"), - "Input(Right) shouldn't be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), - "Input(Out@GRAD) shouldn't be null."); - auto dims = ctx.Input("Left")->dims(); - auto *left_grad = - ctx.Output(framework::GradVarName("Left")); - auto *right_grad = - ctx.Output(framework::GradVarName("Right")); - if (left_grad) { - left_grad->Resize(dims); + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput("Left"), "Input(Left) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput("Right"), "Input(Right) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) shouldn't be null."); + auto dims = ctx->GetInputDim("Left"); + auto left_grad_name = framework::GradVarName("Left"); + auto right_grad_name = framework::GradVarName("Right"); + + if (ctx->HasOutput(left_grad_name)) 
{ + ctx->SetOutputDim(left_grad_name, dims); } - if (right_grad) { - right_grad->Resize(dims); + + if (ctx->HasOutput(right_grad_name)) { + ctx->SetOutputDim(right_grad_name, dims); } } }; diff --git a/paddle/operators/reshape_op.cc b/paddle/operators/reshape_op.cc index ddb93007e21e4d1ae4be3650019c8bc6a680252d..a3c3fa2716ad9f6487e3eff2d98b2c76d964ddef 100644 --- a/paddle/operators/reshape_op.cc +++ b/paddle/operators/reshape_op.cc @@ -26,14 +26,14 @@ class ReshapeOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - void InferShape(const framework::InferShapeContext &ctx) const override { + void InferShape(framework::InferShapeContextBase *ctx) const override { // input check - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input(X) of ReshapeOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), - "Output(Out) of ReshapeOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of ReshapeOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of ReshapeOp should not be null."); - auto shape = ctx.Attr>("shape"); + auto shape = ctx->Attrs().Get>("shape"); PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty."); for (auto dim : shape) { PADDLE_ENFORCE(dim > 0, "Each dimension of shape must be positive."); @@ -41,8 +41,8 @@ class ReshapeOp : public framework::OperatorWithKernel { // capacity check int64_t capacity = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); - auto *in = ctx.Input("X"); - int64_t in_size = framework::product(in->dims()); + auto x_dims = ctx->GetInputDim("X"); + int64_t in_size = framework::product(x_dims); PADDLE_ENFORCE_EQ(capacity, in_size, "The size of Input(X) mismatches with Attr(shape)."); // resize output @@ -50,11 +50,11 @@ class ReshapeOp : public framework::OperatorWithKernel { std::transform(shape.begin(), shape.end(), shape_int64.begin(), [](int a) { return static_cast(a); }); auto out_dims = framework::make_ddim(shape_int64); - ctx.Output("Out")->Resize(out_dims); - if (shape[0] == in->dims()[0]) { + ctx->SetOutputDim("Out", out_dims); + if (shape[0] == x_dims[0]) { // Only pass LoD when the first dimension is equal between // output and input. 
- ctx.ShareLoD("X", /*->*/ "Out"); + ctx->ShareLoD("X", /*->*/ "Out"); } } }; @@ -76,7 +76,7 @@ Given a 2-D tensor X with 2 rows and 2 columns [[1, 2], [3, 4]] -with target shape = [1, 4], the reshape operator will transform +with target shape = [1, 4], the reshape operator will transform the tensor X into a 1-D tensor: [1, 2, 3, 4] @@ -94,13 +94,11 @@ class ReshapeGradOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) shouldn't be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), - "Input(Out@GRAD) shouldn't be null."); - auto dims = ctx.Input("X")->dims(); - auto *d_in = ctx.Output(framework::GradVarName("X")); - d_in->Resize(dims); + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) shouldn't be null."); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); } }; diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc index fc3ad721f210213491617452141dfa8834b067c0..1fcf0959dffd6a68d97dec4e2b5b509d06c0d09c 100644 --- a/paddle/operators/rowwise_add_op.cc +++ b/paddle/operators/rowwise_add_op.cc @@ -24,16 +24,16 @@ class RowwiseAddOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input(X) of RowwiseAddOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("b"), - "Input(b) of RowwiseAddOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), - "Output(Out) of RowwiseAddOp should not be null."); - - auto x_dims = ctx.Input("X")->dims(); - auto b_dims = ctx.Input("b")->dims(); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of RowwiseAddOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("b"), + "Input(b) of RowwiseAddOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of RowwiseAddOp should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + auto b_dims = ctx->GetInputDim("b"); PADDLE_ENFORCE_GT( x_dims.size(), b_dims.size(), "The rank of input `X` must be larger than the one of input `b`."); @@ -43,16 +43,17 @@ class RowwiseAddOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ( framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims, "The width of two operands must be same"); - PADDLE_ENFORCE_EQ(ctx.OutputSize("Out"), 1, "The output size must be 1"); - ctx.Output("Out")->Resize(x_dims); - ctx.ShareLoD("X", /*->*/ "Out"); + PADDLE_ENFORCE_EQ(ctx->Outputs("Out").size(), 1, + "The output size must be 1"); + ctx->SetOutputDim("Out", x_dims); + ctx->ShareLoD("X", /*->*/ "Out"); } }; class RowwiseAddOpMaker : public framework::OpProtoAndCheckerMaker { public: - RowwiseAddOpMaker(framework::OpProto *proto, - framework::OpAttrChecker *op_checker) + RowwiseAddOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The left input of row-wise add op, must be matrix"); AddInput("b", "The right input of row-wise add op, must be vector"); @@ -69,25 
+70,29 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X should not be null"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("b"), "b should not be null"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), - "Input(Out@GRAD) should not be null"); - auto x_dims = ctx.Input("X")->dims(); - auto b_dims = ctx.Input("b")->dims(); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "X should not be null"); + PADDLE_ENFORCE(ctx->HasInput("b"), "b should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto x_dims = ctx->GetInputDim("X"); + auto b_dims = ctx->GetInputDim("b"); PADDLE_ENFORCE_GT( x_dims.size(), b_dims.size(), "The rank of input `X` must be larger than the one of input `b`."); - int num_col_dims = x_dims.size() - b_dims.size(); + int64_t num_col_dims = x_dims.size() - b_dims.size(); PADDLE_ENFORCE_EQ( framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims, "The width of two operands must be same"); - auto *dx = ctx.Output(framework::GradVarName("X")); - auto *db = ctx.Output(framework::GradVarName("b")); - if (dx) dx->Resize(x_dims); - if (db) db->Resize(b_dims); + auto x_grad_name = framework::GradVarName("X"); + auto b_grad_name = framework::GradVarName("b"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + if (ctx->HasOutput(b_grad_name)) { + ctx->SetOutputDim(b_grad_name, b_dims); + } } }; diff --git a/paddle/operators/scale_op.cc b/paddle/operators/scale_op.cc index 1ae77a9722ef1a5548a6c4100c32fdddcee8c5cd..e92501e12834b92875f494de401672344f50e3b5 100644 --- a/paddle/operators/scale_op.cc +++ b/paddle/operators/scale_op.cc @@ -26,16 +26,13 @@ class ScaleOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input(X) of ScaleOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), - "Output(Out) of ScaleOp should not be null."); - - auto *in = ctx.Input("X"); - auto *out = ctx.Output("Out"); - out->Resize(in->dims()); - ctx.ShareLoD("X", /*->*/ "Out"); + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of ScaleOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of ScaleOp should not be null."); + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + ctx->ShareLoD("X", /*->*/ "Out"); } }; diff --git a/paddle/operators/scatter_op.cc b/paddle/operators/scatter_op.cc index 3f02081a060281dec533c02b346f0667da28b8c3..3fc4a39ebc5526bfed61ba667c3cdc214cdd056c 100644 --- a/paddle/operators/scatter_op.cc +++ b/paddle/operators/scatter_op.cc @@ -23,29 +23,30 @@ class ScatterOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Ref"), - "Input(Ref) of ScatterOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Index"), - "Input(Index) of ScatterOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Updates"), - 
"Input(Updates) of ScatterOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), - "Output(Out) of ScatterOp should not be null."); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Ref"), + "Input(Ref) of ScatterOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Index"), + "Input(Index) of ScatterOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Updates"), + "Input(Updates) of ScatterOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of ScatterOp should not be null."); - PADDLE_ENFORCE_EQ(ctx.Input("Index")->dims().size(), 1, + auto updates_dims = ctx->GetInputDim("Updates"); + auto ref_dims = ctx->GetInputDim("Ref"); + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Index").size(), 1, "Update Index should be 1-D."); - PADDLE_ENFORCE_EQ(ctx.Input("Ref")->dims().size(), - ctx.Input("Updates")->dims().size(), + PADDLE_ENFORCE_EQ(ref_dims.size(), updates_dims.size(), "Reference and Updates should have the same shape size"); - PADDLE_ENFORCE_EQ(ctx.Input("Updates")->dims()[0], - ctx.Input("Index")->dims()[0], + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Updates")[0], + ctx->GetInputDim("Index")[0], "Updates and Index should have same batch-size."); - framework::DDim data_dim(ctx.Input("Updates")->dims()); - for (int i = 1; i < data_dim.size(); ++i) - PADDLE_ENFORCE_EQ(data_dim[i], ctx.Input("Updates")->dims()[i]); - ctx.Output("Out")->Resize( - ctx.Input("Ref")->dims()); + framework::DDim data_dim(updates_dims); + for (int i = 1; i < data_dim.size(); ++i) { + PADDLE_ENFORCE_EQ(data_dim[i], updates_dims[i]); + } + ctx->SetOutputDim("Out", ref_dims); } }; @@ -54,22 +55,17 @@ class ScatterGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - auto *dUpdates = - ctx.Output(framework::GradVarName("Updates")); - auto *Updates = ctx.Input("Updates"); - auto *dRef = ctx.Output(framework::GradVarName("Ref")); - auto *Ref = ctx.Input("Ref"); - - dRef->Resize(Ref->dims()); - dUpdates->Resize(Updates->dims()); + void InferShape(framework::InferShapeContextBase* ctx) const override { + ctx->SetOutputDim(framework::GradVarName("Updates"), + ctx->GetInputDim("Updates")); + ctx->SetOutputDim(framework::GradVarName("Ref"), ctx->GetInputDim("Ref")); } }; class ScatterOpMaker : public framework::OpProtoAndCheckerMaker { public: - ScatterOpMaker(framework::OpProto *proto, - framework::OpAttrChecker *op_checker) + ScatterOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("Ref", "The source input of scatter op"); AddInput("Index", @@ -77,13 +73,14 @@ class ScatterOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Updates", "The updated value of updates op"); AddOutput("Out", "The output of add op"); AddComment(R"DOC( -Scatter Operator by selecting from the first axis, +Scatter Operator by selecting from the first axis, Out = Ref Out[Index] = Ref[Index] + Updates )DOC"); } }; + } // namespace operators } // namespace paddle diff --git a/paddle/operators/sequence_pool_op.cc b/paddle/operators/sequence_pool_op.cc index 73f9cb879a2ef690909428b3b672b12717a6a02c..17685ea654715f6996e17f6228f266c3aa1ee424 100644 --- a/paddle/operators/sequence_pool_op.cc +++ b/paddle/operators/sequence_pool_op.cc @@ -22,23 +22,12 @@ class SequencePoolOp : public framework::OperatorWithKernel { using 
framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext& ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input(X) of SequencePoolOp should not be null."); - PADDLE_ENFORCE_NOT_NULL( - ctx.OutputVar("Out"), - "Output(Out) of SequencePoolOp should not be null."); - - auto* x = ctx.Input("X"); - auto dims = x->dims(); - auto lod = x->lod(); - PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now."); - PADDLE_ENFORCE_GE( - dims[0], - /*batch size = */ static_cast<int64_t>(lod[0].size() - 1), - "The first dimension of Input(X) must be large than batch size."); - dims[0] = lod[0].size() - 1; - ctx.Output("Out")->Resize({dims}); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SequencePoolOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of SequencePoolOp should not be null."); + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); } }; @@ -61,17 +50,17 @@ class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker { SequencePoolOp pools features of all time-steps of each instance. For a mini-batch of 3 variable-length sentences, containing 2, 3, and 2 time-steps: - + Assume X is a [7,M,N] float LoDTensor, and X->lod()[0] = [0, 2, 5, 7]. - Besides, for the sake of simplicity, we assume M=1 and N=1, + Besides, for the sake of simplicity, we assume M=1 and N=1, and the value of X = [[1, 3], [2, 4, 6], [5, 1]]. Thus, Out is a [3,1,1] float LoDTensor, but Out->lod() is nullptr. - And for different strategy, the value of Out is as follows: + And for different strategy, the value of Out is as follows: - AVERAGE: [2, 4, 3], where 2=(1+3)/2, 4=(2+4+6)/3, 3=(5+1)/2 - SUM: [4, 12, 6], where 4=1+3, 12=2+4+6, 6=5+1 - - SQRT: [2.82, 6.93, 4.24], where 2.82=(1+3)/sqrt(2), + - SQRT: [2.82, 6.93, 4.24], where 2.82=(1+3)/sqrt(2), 6.93=(2+4+6)/sqrt(3), 4.24=(5+1)/sqrt(2) - MAX: [3, 6, 5], where 3=max(1,3), 6=max(2,4,6), 5=max(5,1) - LAST: [3, 6, 1], where 3=last(1,3), 6=last(2,4,6), 1=last(5,1) @@ -85,22 +74,18 @@ class SequencePoolGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext& ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), - "Gradient of Out should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "The input X should not be null."); - auto og_dims = - ctx.Input(framework::GradVarName("Out"))->dims(); - auto x_dims = ctx.Input("X")->dims(); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Gradient of Out should not be null."); + PADDLE_ENFORCE(ctx->HasInput("X"), "The input X should not be null."); + auto og_dims = ctx->GetInputDim(framework::GradVarName("Out")); + auto x_dims = ctx->GetInputDim("X"); PADDLE_ENFORCE_EQ(og_dims.size(), x_dims.size(), "The rank of output grad must be equal to that of Input(X)."); for (int64_t i = 1; i < og_dims.size(); ++i) { PADDLE_ENFORCE_EQ(og_dims[i], x_dims[i], "The dimension mismatch."); } - auto* x_grad = - ctx.Output(framework::GradVarName("X")); - x_grad->Resize(x_dims); + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); } }; diff --git a/paddle/operators/sequence_pool_op.h b/paddle/operators/sequence_pool_op.h index 231614b4c1cb0eb1901b1720e933aed5cbb25f77..cb80586e88f8d9e31b7b91a54f5e05ac6fa73f0f
100644 --- a/paddle/operators/sequence_pool_op.h +++ b/paddle/operators/sequence_pool_op.h @@ -46,16 +46,27 @@ class SequencePoolKernel : public framework::OpKernel { int strategy = context.Attr<int>("strategy"); auto dims = in->dims(); - auto lod = in->lod()[0]; + auto lod = in->lod(); int64_t w = in->numel() / dims[0]; + // InferShape by lod + PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now."); + PADDLE_ENFORCE_GE( + dims[0], + /*batch size = */ static_cast<int64_t>(lod[0].size() - 1), + "The first dimension of Input(X) must be larger than batch size."); + dims[0] = lod[0].size() - 1; + out->Resize({dims}); + + auto lod_level_0 = lod[0]; + out->mutable_data<T>(context.GetPlace()); auto place = context.GetEigenDevice<Place>(); - for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) { - Tensor in_t = - in->Slice(static_cast<int>(lod[i]), static_cast<int>(lod[i + 1])); + for (int i = 0; i < static_cast<int>(lod_level_0.size()) - 1; ++i) { + Tensor in_t = in->Slice(static_cast<int>(lod_level_0[i]), + static_cast<int>(lod_level_0[i + 1])); Tensor out_t = out->Slice(i, i + 1); - int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]); + int64_t h = static_cast<int64_t>(lod_level_0[i + 1] - lod_level_0[i]); auto in_e = EigenMatrix<T>::From(in_t, framework::make_ddim({h, w})); auto out_e = EigenVector<T>::Flatten(out_t); diff --git a/paddle/operators/sgd_op.cc b/paddle/operators/sgd_op.cc index b063e2427217f20eb89f7cd1af0354ad0e400feb..3bce95535cf10c0df95b503c6e362b3f0ba2e723 100644 --- a/paddle/operators/sgd_op.cc +++ b/paddle/operators/sgd_op.cc @@ -22,19 +22,18 @@ class SGDOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("param"), - "Input(param) of SGDOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("grad"), - "Input(grad) of SGDOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("param_out"), - "Output(param_out) of SGDOp should not be null."); - - PADDLE_ENFORCE_EQ(ctx.Input("param")->dims(), - ctx.Input("grad")->dims(), + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("param"), + "Input(param) of SGDOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("grad"), + "Input(grad) of SGDOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("param_out"), + "Output(param_out) of SGDOp should not be null."); + + auto param_dim = ctx->GetInputDim("param"); + PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("grad"), "The two inputs of SGD Op must have the same dimension."); - ctx.Output("param_out") - ->Resize(ctx.Input("param")->dims()); + ctx->SetOutputDim("param_out", param_dim); } }; diff --git a/paddle/operators/smooth_l1_loss_op.cc b/paddle/operators/smooth_l1_loss_op.cc index ae6d1c80b300690b070024d6266a1b99bf2ef04f..2d197e3b1b763fa87939623d47728aab3bff7cd1 100644 --- a/paddle/operators/smooth_l1_loss_op.cc +++ b/paddle/operators/smooth_l1_loss_op.cc @@ -22,33 +22,28 @@ class SmoothL1LossOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext& ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X must be initialized."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Y must be initialized."); - - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - PADDLE_ENFORCE_EQ(x->dims(), y->dims(), - "The shape of X and Y must be the same."); -
PADDLE_ENFORCE_GE(x->dims().size(), 2, + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "X must be initialized."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Y must be initialized."); + + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + PADDLE_ENFORCE_EQ(x_dims, y_dims, "The shape of X and Y must be the same."); + PADDLE_ENFORCE_GE(x_dims.size(), 2, "The tensor rank of X must be at least 2."); - auto* inside_weight = ctx.Input("InsideWeight"); - if (inside_weight) { - auto* outside_weight = ctx.Input("OutsideWeight"); - PADDLE_ENFORCE_NOT_NULL(outside_weight, - "If weights are provided, must specify both " - "inside and outside weights."); - PADDLE_ENFORCE_EQ(inside_weight->dims(), x->dims(), + if (ctx->HasInput("InsideWeight")) { + PADDLE_ENFORCE(ctx->HasInput("OutsideWeight"), + "If weights are provided, must specify both " + "inside and outside weights."); + PADDLE_ENFORCE_EQ(ctx->GetInputDim("InsideWeight"), x_dims, "The shape of InsideWeight must be same as X."); - PADDLE_ENFORCE_EQ(outside_weight->dims(), x->dims(), + PADDLE_ENFORCE_EQ(ctx->GetInputDim("OutsideWeight"), x_dims, "The shape of OutsideWeight must be same as X."); } - auto* diff = ctx.Output("Diff"); - auto* out = ctx.Output("Out"); - diff->Resize(x->dims()); + ctx->SetOutputDim("Diff", x_dims); // loss is a two-rank tensor - out->Resize({x->dims()[0], 1}); + ctx->SetOutputDim("Out", {x_dims[0], 1}); } }; @@ -99,12 +94,9 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext& ctx) const override { - auto in_dims = ctx.Input("X")->dims(); - auto out_dims = - ctx.Input(framework::GradVarName("Out"))->dims(); - auto* x_grad = ctx.Output(framework::GradVarName("X")); - auto* y_grad = ctx.Output(framework::GradVarName("Y")); + void InferShape(framework::InferShapeContextBase* ctx) const override { + auto in_dims = ctx->GetInputDim("X"); + auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); PADDLE_ENFORCE_GE(out_dims.size(), 2, "The tensor rank of Input(Out@Grad) should be 2."); @@ -114,8 +106,14 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(out_dims[1], 1, "The 2nd dimension of Input(Out@Grad) must be 1."); - if (x_grad) x_grad->Resize(in_dims); - if (y_grad) y_grad->Resize(in_dims); + auto x_grad_name = framework::GradVarName("X"); + auto y_grad_name = framework::GradVarName("Y"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, in_dims); + } + if (ctx->HasOutput(y_grad_name)) { + ctx->SetOutputDim(y_grad_name, in_dims); + } } }; diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc index e15cfe485016552971924a40a172e74a90629dce..e353afee3e10247fbd5c7f4282c366cd1bc39552 100644 --- a/paddle/operators/softmax_op.cc +++ b/paddle/operators/softmax_op.cc @@ -22,22 +22,23 @@ class SoftmaxOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input(X) of SoftmaxOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"), - "Output(Y) of SoftmaxOp should not be null."); - - PADDLE_ENFORCE(ctx.Input("X")->dims().size() == 2UL, + void InferShape(framework::InferShapeContextBase* ctx) const override { + 
PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SoftmaxOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Y"), + "Output(Y) of SoftmaxOp should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE(x_dims.size() == 2UL, "The input of softmax op must be a matrix."); - ctx.Output("Y")->Resize(ctx.Input("X")->dims()); + ctx->SetOutputDim("Y", x_dims); } }; class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { public: - SoftmaxOpMaker(framework::OpProto *proto, - framework::OpAttrChecker *op_checker) + SoftmaxOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The input tensor of softmax. " @@ -68,16 +69,15 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should be not null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")), - "Input(Y@GRAD) should be not null."); - PADDLE_ENFORCE_EQ(ctx.Input("Y")->dims(), - ctx.Input(framework::GradVarName("Y"))->dims(), + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should be not null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), + "Input(Y@GRAD) should be not null."); + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Y"), + ctx->GetInputDim(framework::GradVarName("Y")), "Input(Y) and its gradients should have a same shape."); - ctx.Output(framework::GradVarName("X")) - ->Resize(ctx.Input("X")->dims()); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); } }; diff --git a/paddle/operators/split_op.cc b/paddle/operators/split_op.cc index a9d35b4fb79ae83379552ae2c2b4d694bd8f86dd..8640d1010ef6ae352a93ee2fd7b771a90c6efa5c 100644 --- a/paddle/operators/split_op.cc +++ b/paddle/operators/split_op.cc @@ -24,40 +24,42 @@ class SplitOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - // infershape - auto *in = ctx.Input("X"); - auto outs = ctx.MultiOutput("Out"); - size_t axis = static_cast(ctx.Attr("axis")); - size_t num = static_cast(ctx.Attr("num")); - std::vector sections = - static_cast>(ctx.Attr>("sections")); - const size_t n = outs.size(); + void InferShape(framework::InferShapeContextBase *ctx) const override { + auto in_dims = ctx->GetInputDim("X"); + auto outs_names = ctx->Outputs("Out"); + size_t axis = static_cast(ctx->Attrs().Get("axis")); + size_t num = static_cast(ctx->Attrs().Get("num")); + std::vector sections = static_cast>( + ctx->Attrs().Get>("sections")); + const size_t outs_number = outs_names.size(); + std::vector outs_dims; + outs_dims.reserve(outs_number); if (num > 0) { - int64_t in_axis_dim = in->dims()[axis]; + int64_t in_axis_dim = in_dims[axis]; PADDLE_ENFORCE_EQ(in_axis_dim % num, 0, "tensor split does not result" " in an equal division"); size_t out_axis_dim = in_axis_dim / num; - for (size_t i = 0; i < n; ++i) { - auto dim = in->dims(); + for (size_t i = 0; i < outs_number; ++i) { + auto dim = in_dims; dim[axis] = out_axis_dim; - outs[i]->Resize(dim); + outs_dims.push_back(dim); } } else if (sections.size() > 0) { - PADDLE_ENFORCE_EQ(sections.size(), n, + PADDLE_ENFORCE_EQ(sections.size(), outs_number, "tensor split sections size" "should be 
equal to output size."); - for (size_t i = 0; i < n; ++i) { - auto dim = in->dims(); + for (size_t i = 0; i < outs_number; ++i) { + auto dim = in_dims; dim[axis] = sections[i]; - outs[i]->Resize(dim); + outs_dims.push_back(dim); } } else { PADDLE_ENFORCE_NOT_NULL(nullptr, "split operator should", " specify indices or sections."); } + ctx->SetOutputsDim("Out", outs_dims); } }; diff --git a/paddle/operators/squared_l2_distance_op.cc b/paddle/operators/squared_l2_distance_op.cc index 33a564b05b1b490c6d23b7d17cef45b7740dfa39..5a0cb596008a98aacf5e7b5ff70307ea1b8508e6 100644 --- a/paddle/operators/squared_l2_distance_op.cc +++ b/paddle/operators/squared_l2_distance_op.cc @@ -22,24 +22,19 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext& ctx) const override { - PADDLE_ENFORCE_NOT_NULL( - ctx.InputVar("X"), - "Input(X) of SquaredL2DistanceOp should not be null."); - PADDLE_ENFORCE_NOT_NULL( - ctx.InputVar("Y"), - "Input(Y) of SquaredL2DistanceOp should not be null."); - PADDLE_ENFORCE_NOT_NULL( - ctx.OutputVar("sub_result"), + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SquaredL2DistanceOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), + "Input(Y) of SquaredL2DistanceOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("sub_result"), "Output(sub_result) of SquaredL2DistanceOp should not be null."); - PADDLE_ENFORCE_NOT_NULL( - ctx.OutputVar("Out"), - "Output(Out) of SquaredL2DistanceOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of SquaredL2DistanceOp should not be null."); - auto* x = ctx.Input("X"); - auto x_dims = x->dims(); - auto* y = ctx.Input("Y"); - auto y_dims = y->dims(); + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); PADDLE_ENFORCE_EQ(framework::arity(x_dims), framework::arity(y_dims), "Tensor rank of both SquaredL2DistanceOp's " @@ -47,17 +42,16 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel { int rank = framework::arity(x_dims); PADDLE_ENFORCE_GE(rank, 2, "Tensor rank should be at least equal to 2."); - PADDLE_ENFORCE_EQ(x->numel() / x_dims[0], y->numel() / y_dims[0], + PADDLE_ENFORCE_EQ(product(x_dims) / x_dims[0], product(y_dims) / y_dims[0], "Product of dimensions expcet the first dimension of " "input and target must be equal."); PADDLE_ENFORCE(y_dims[0] == 1 || y_dims[0] == x_dims[0], "First dimension of target must be equal to input " "or to 1."); - ctx.Output("sub_result") - ->Resize({x_dims[0], x->numel() / x_dims[0]}); - ctx.Output("Out")->Resize({x_dims[0], 1}); - ctx.ShareLoD("X", /*->*/ "Out"); + ctx->SetOutputDim("sub_result", {x_dims[0], product(x_dims) / x_dims[0]}); + ctx->SetOutputDim("Out", {x_dims[0], 1}); + ctx->ShareLoD("X", /*->*/ "Out"); } }; @@ -92,22 +86,22 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext& ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), - "Gradient of Out should not be null"); - auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims(); - auto x_dims = ctx.Input("X")->dims(); - auto y_dims = ctx.Input("Y")->dims(); + void InferShape(framework::InferShapeContextBase* ctx) const override { + 
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Gradient of Out should not be null"); + auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); PADDLE_ENFORCE_EQ(out_dims[0], x_dims[0], "First dimension of output gradient and " "input value must be equal."); PADDLE_ENFORCE_EQ(out_dims[1], 1, "Second dimension of output gradient " "must be 1."); - auto* x_grad = ctx.Output(framework::GradVarName("X")); - auto* y_grad = ctx.Output(framework::GradVarName("Y")); - if (x_grad) x_grad->Resize(x_dims); - if (y_grad) y_grad->Resize(y_dims); + auto x_grad_name = framework::GradVarName("X"); + auto y_grad_name = framework::GradVarName("Y"); + if (ctx->HasOutput(x_grad_name)) ctx->SetOutputDim(x_grad_name, x_dims); + if (ctx->HasOutput(y_grad_name)) ctx->SetOutputDim(y_grad_name, y_dims); } }; diff --git a/paddle/operators/sum_op.cc b/paddle/operators/sum_op.cc index 437fc262f359525045a4d772ee2c204ef571caa7..8f62a9f4db8d39edc11949df513aebf4fa257d45 100644 --- a/paddle/operators/sum_op.cc +++ b/paddle/operators/sum_op.cc @@ -21,31 +21,27 @@ class SumOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(), - "Input(X) of SumOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), - "Output(Out) of SumOp should not be null."); - - auto ins = ctx.MultiInput("X"); - auto *out = ctx.Output("Out"); - int N = ins.size(); - - auto in_dim = ins[0]->dims(); + void InferShape(framework::InferShapeContextBase* ctx) const override { + auto x_dims = ctx->GetInputsDim("X"); + PADDLE_ENFORCE(!x_dims.empty(), "Input(X) of SumOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of SumOp should not be null."); + auto in_dim = x_dims[0]; + size_t N = x_dims.size(); PADDLE_ENFORCE_GT(N, 1, "Input tensors count should be greater than 1."); - for (int i = 1; i < N; i++) { - auto dim = ins[i]->dims(); + for (size_t i = 1; i < N; i++) { + auto dim = x_dims[i]; PADDLE_ENFORCE(in_dim == dim, "Input tensors must have the same shape"); } - out->Resize(in_dim); - ctx.ShareLoD("X", /*->*/ "Out"); + ctx->SetOutputDim("Out", in_dim); + ctx->ShareLoD("X", /*->*/ "Out"); } }; class SumOpMaker : public framework::OpProtoAndCheckerMaker { public: - SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + SumOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "the input tensors of sum operator.").AsDuplicable(); AddOutput("Out", "the output tensor of sum operator."); @@ -63,13 +59,16 @@ class SumGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - auto outputs = - ctx.MultiOutput(framework::GradVarName("X")); - auto dims = ctx.Input(framework::GradVarName("Out"))->dims(); - for (auto output : outputs) { - output->Resize(dims); + void InferShape(framework::InferShapeContextBase* ctx) const override { + auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out")); + auto x_grad_names = ctx->Outputs(framework::GradVarName("X")); + size_t x_length = x_grad_names.size(); + std::vector<framework::DDim> x_grad_dims; + x_grad_dims.reserve(x_length); + for (size_t i = 0; i < x_length; ++i) { +
diff --git a/paddle/operators/top_k_op.cc b/paddle/operators/top_k_op.cc
index a6e43964e9825cd1ced9e7c1bc8d691422248fee..5f22bf1df8720b60aba7cd75896d88cd1ad77635 100644
--- a/paddle/operators/top_k_op.cc
+++ b/paddle/operators/top_k_op.cc
@@ -22,26 +22,26 @@ class TopkOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
-                            "Input(X) of TopkOp should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
-                            "Output(Out) of TopkOp should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Indices"),
-                            "Output(Indices) of TopkOp should not be null.");
+  void InferShape(framework::InferShapeContextBase *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of TopkOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of TopkOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Indices"),
+                   "Output(Indices) of TopkOp should not be null.");

-    auto *input = ctx.Input<framework::Tensor>("X");
-    const int k = static_cast<int>(ctx.Attr<int>("k"));
+    auto input_dims = ctx->GetInputDim("X");
+    const int k = static_cast<int>(ctx->Attrs().Get<int>("k"));

     PADDLE_ENFORCE_GE(k, 1, "k must >= 1");
-    PADDLE_ENFORCE_GE(input->dims().size(), 1, "input must have >= 1d shape");
-    PADDLE_ENFORCE_GE(input->dims()[input->dims().size() - 1], k,
+    PADDLE_ENFORCE_GE(input_dims.size(), 1, "input must have >= 1d shape");
+    PADDLE_ENFORCE_GE(input_dims[input_dims.size() - 1], k,
                       "input must have >= k columns");

-    framework::DDim dims = input->dims();
+    framework::DDim dims = input_dims;
     dims[dims.size() - 1] = k;
-    ctx.Output<framework::LoDTensor>("Out")->Resize(dims);
-    ctx.Output<framework::LoDTensor>("Indices")->Resize(dims);
+    ctx->SetOutputDim("Out", dims);
+    ctx->SetOutputDim("Indices", dims);
   }
 };
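TopkOp's output shape depends on an attribute; under the context interface, attribute values are read through ctx->Attrs().Get<T>() instead of the operator's own Attr<T>() accessor. A sketch of that pattern on a hypothetical padding op (the op and attribute names are invented):

    // Hypothetical Pad1DOp: grow the last dimension of X by a "pad" attribute.
    void InferShape(framework::InferShapeContextBase* ctx) const override {
      PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
      PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
      auto x_dims = ctx->GetInputDim("X");
      const int pad = ctx->Attrs().Get<int>("pad");  // read by attribute name
      PADDLE_ENFORCE_GE(pad, 0, "pad must be >= 0");
      framework::DDim out_dims = x_dims;
      out_dims[out_dims.size() - 1] += pad;
      ctx->SetOutputDim("Out", out_dims);
    }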
diff --git a/paddle/operators/transpose_op.cc b/paddle/operators/transpose_op.cc
index 017a05326e9b397185d7c3530891884b11784783..0672f9342dac00ecc3f358450a9a203327cbb51f 100644
--- a/paddle/operators/transpose_op.cc
+++ b/paddle/operators/transpose_op.cc
@@ -24,12 +24,11 @@ class TransposeOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
-                            "Output(Out) should not be null");
-    auto x_dims = ctx.Input<Tensor>("X")->dims();
-    std::vector<int> axis = ctx.Attr<std::vector<int>>("axis");
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null");
+    auto x_dims = ctx->GetInputDim("X");
+    std::vector<int> axis = ctx->Attrs().Get<std::vector<int>>("axis");
     size_t x_rank = x_dims.size();
     size_t axis_size = axis.size();

@@ -51,14 +50,14 @@ class TransposeOp : public framework::OperatorWithKernel {
     for (size_t i = 0; i < axis_size; i++) {
       out_dims[i] = x_dims[axis[i]];
     }
-    ctx.Output<framework::LoDTensor>("Out")->Resize(out_dims);
+    ctx->SetOutputDim("Out", out_dims);
   }
 };

 class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  TransposeOpMaker(framework::OpProto *proto,
-                   framework::OpAttrChecker *op_checker)
+  TransposeOpMaker(framework::OpProto* proto,
+                   framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput(
         "X",
@@ -79,7 +78,7 @@ For example:
       [3, 4, 5]])
  >> axis = [1, 0]
  >> output = input.transpose(axis)
- >> output
+ >> output
    array([[0, 3],
           [1, 4],
           [2, 5]])
@@ -94,14 +93,14 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
-                            "Input(Out@GRAD) should not be null");
-    auto x_dims = ctx.Input<Tensor>("X")->dims();
-    auto *x_grad = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
-
-    if (x_grad) x_grad->Resize(x_dims);
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@GRAD) should not be null");
+    auto x_dims = ctx->GetInputDim("X");
+    if (ctx->HasOutput(framework::GradVarName("X"))) {
+      ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
+    }
   }
 };
diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc
index 17ea48361bc597ccfeb80884d51900e6567aa057..2771df56086ff261728af84edcdf01cda3e45e9f 100644
--- a/paddle/operators/uniform_random_op.cc
+++ b/paddle/operators/uniform_random_op.cc
@@ -23,18 +23,18 @@ namespace operators {
 template <typename T>
 class CPUUniformRandomKernel : public framework::OpKernel {
  public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto* tensor = context.Output<framework::Tensor>("Out");
-    T* data = tensor->mutable_data<T>(context.GetPlace());
-    unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* tensor = ctx.Output<framework::Tensor>("Out");
+    T* data = tensor->mutable_data<T>(ctx.GetPlace());
+    unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
     std::minstd_rand engine;
     if (seed == 0) {
       seed = std::random_device()();
     }
     engine.seed(seed);
     std::uniform_real_distribution<T> dist(
-        static_cast<T>(context.Attr<float>("min")),
-        static_cast<T>(context.Attr<float>("max")));
+        static_cast<T>(ctx.Attr<float>("min")),
+        static_cast<T>(ctx.Attr<float>("max")));
     int64_t size = tensor->numel();
     for (int64_t i = 0; i < size; ++i) {
       data[i] = dist(engine);
@@ -47,21 +47,20 @@ class UniformRandomOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(const framework::InferShapeContext& ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(
-        ctx.OutputVar("Out"),
-        "Output(Out) of UniformRandomOp should not be null.");
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of UniformRandomOp should not be null.");

-    PADDLE_ENFORCE(Attr<float>("min") < Attr<float>("max"),
-                   "uniform_random's min must less then max");
-    auto* tensor = ctx.Output<framework::Tensor>("Out");
+    PADDLE_ENFORCE(
+        ctx->Attrs().Get<float>("min") < ctx->Attrs().Get<float>("max"),
+        "uniform_random's min must be less than max");
     auto dims = Attr<std::vector<int>>("dims");
     std::vector<int64_t> temp;
     temp.reserve(dims.size());
     for (auto dim : dims) {
       temp.push_back(static_cast<int64_t>(dim));
     }
-    tensor->Resize(framework::make_ddim(temp));
+    ctx->SetOutputDim("Out", framework::make_ddim(temp));
   }
 };
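One inconsistency worth flagging in the uniform_random hunk: the min/max guard now goes through ctx->Attrs().Get<float>(), while the very next line still reads "dims" via the operator's own Attr<std::vector<int>>() accessor. Both compile, since InferShape remains a member function, but the context-based form is presumably the intended style. For ops like this one, whose output shape comes entirely from attributes, the whole method reduces to the following sketch (op and attribute names are invented):

    // Hypothetical FillRandomOp: the output shape is given by a "shape"
    // attribute, so no input dims are consulted at all.
    void InferShape(framework::InferShapeContextBase* ctx) const override {
      PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
      auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
      // Widen to int64_t, mirroring the conversion done in the hunk above.
      std::vector<int64_t> dims(shape.begin(), shape.end());
      ctx->SetOutputDim("Out", framework::make_ddim(dims));
    }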