diff --git a/doc/howto/dev/new_op_en.md b/doc/howto/dev/new_op_en.md index b7aa501db9e5c7378398fad48503f82bff893b60..60681cdd718547e1abc730ea720f05bbd39561f1 100644 --- a/doc/howto/dev/new_op_en.md +++ b/doc/howto/dev/new_op_en.md @@ -1,14 +1,17 @@ # How to write a new operator - - [Background](#Background) - - [Implementing C++ Types](#Implementing_C++_Types) - - [Defining ProtoMaker](#Defining_ProtoMaker) - - [Defining Operator](#Defining_Operator) - - [Registering Operator](#Registering_Operator) - - [Compilation](#Compilation) - - [Python Binding](#Python_Binding) - - [Unit Tests](#Unit_Tests) - + - [Background](#background) + - [Implementing C++ Types](#implementing-c++-types) + - [Defining ProtoMaker](#defining-protoMaker) + - [Defining Operator](#defining-operator) + - [Registering Operator](#registering-operator) + - [Compilation](#compilation) + - [Python Binding](#python-binding) + - [Unit Tests](#unit-tests) + - [Testing Forward Operators](#testing-forward-operators) + - [Testing Backward Operators](#testing-backward-operators) + - [Compiling and Running](#compiling-and-running) + - [Remarks](#remarks) ## Background Here are the base types needed. For details, please refer to the design docs. @@ -232,4 +235,122 @@ The system will automatically bind to Python and link it to a generated library. ## Unit Tests -Unit tests include comparing a forward operator's implementations on different devices, comparing a backward operator's implementation on different devices, and a scaling test for the backward operator. Here, we introduce the [unit tests for `MulOp`](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/framework/tests/test_mul_op.py). +Unit tests for an operator include + +1. comparing a forward operator's implementations on different devices, + +2. comparing a backward operator's implementation on different devices, and + +3. a scaling test for the backward operator. 
+ +Here, we introduce the [unit tests for `MulOp`](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/framework/tests/test_mul_op.py). + +### Testing Forward Operators + +A forward operator unit test inherits `unittest.TestCase` and sets its metaclass with `__metaclass__ = OpTestMeta`. More concrete tests are performed in `OpTestMeta`. Testing a forward operator requires the following: + +1. Defining input, output and relevant attributes in the `setUp` method. + +2. Generating random input data. + +3. Implementing the same computation logic in a Python script: + + ```python + import unittest + import numpy as np + from gradient_checker import GradientChecker, create_op + from op_test_util import OpTestMeta + + class TestMulOp(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = "mul" + self.inputs = { + 'X': np.random.random((32, 84)).astype("float32"), + 'Y': np.random.random((84, 100)).astype("float32") + } + self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])} + ``` +Get its output, and compare it with the forward operator's own output. + +The code above first loads the required packages. In addition, we have + +- `self.type = "mul"` defines the type, which must be identical to the operator's registered type. +- `self.inputs` defines the inputs, of type `numpy.array`, and initializes them. +- `self.outputs` defines the outputs by completing the same operator computation in the Python script and storing its result. + +### Testing Backward Operators + +A backward operator unit test inherits `GradientChecker`, which inherits `unittest.TestCase`. As a result, **the test methods of a backward operator unit test need to have the prefix `test_`**. 
+ +```python +class TestMulGradOp(GradientChecker): + def setUp(self): + self.op = create_op("mul") + self.inputs = { + 'X': np.random.random((32, 84)).astype("float32"), + 'Y': np.random.random((84, 100)).astype("float32") + } + + def test_cpu_gpu_compare(self): + self.compare_grad(self.op, self.inputs) + + def test_normal(self): + # mul op will enlarge the relative error + self.check_grad( + self.op, self.inputs, ["X", "Y"], "Out", max_relative_error=0.5) + + def test_ignore_x(self): + self.check_grad( + self.op, + self.inputs, ["Y"], + "Out", + max_relative_error=0.5, + no_grad_set={"X"}) + + def test_ignore_y(self): + self.check_grad( + self.op, + self.inputs, ["X"], + "Out", + max_relative_error=0.5, + no_grad_set={"Y"}) +``` + +Some key points in the code above include: + +- `create_op("mul")` creates the backward operator's corresponding forward operator. +- `compare_grad` compares the results computed on the CPU with those computed on the GPU. +- `test_normal` calls `check_grad` to validate scaling tests' correctness and stability through numeric methods. + - The first argument `self.op` denotes the forward operator. + - The second argument `self.inputs` denotes the input dictionary, whose keys are identical to the names in its `ProtoMaker` definition. + - The third argument `["X", "Y"]` appoints `X` and `Y` to be scale tested. + - The fourth argument `"Out"` points to the network's final output target `Out`. +- `test_ignore_x` and `test_ignore_y` test the cases where there is only one scaling input. + +### Compiling and Running + + +Any new unit testing file of the format `test_*.py` added to the directory `python/paddle/v2/framework/tests` is automatically included in the project's build. + +Note that **unlike the compile test for Ops, running unit tests requires compiling the entire project** and requires compiling with the flag `WITH_TESTING` on, i.e., `cmake paddle_dir -DWITH_TESTING=ON`. 
+ +After successfully compiling the project, run the following command to run unit tests: + +```bash +make test ARGS="-R test_mul_op -V" +``` + +Or, + +```bash +ctest -R test_mul_op +``` + +## Remarks + +- Every `*_op.h` (if applicable), `*_op.cc`, and `*_op.cu` (if applicable) must be created for a unique Op. Compiling will fail if multiple operators are included per file. +- The type with which an operator is registered needs to be identical to the Op's name. Registering `REGISTER_OP(B, ...)` in `A_op.cc` will cause unit testing failures. +- If the operator does not implement a GPU kernel, please refrain from creating an empty `*_op.cu` file, or else unit tests will fail. +- If multiple operators rely on some shared methods, a file NOT named `*_op.*` can be created to store them, such as `gather.h`. diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index e535f84dba7c2726fbb70fa11ca8e9e2d29b8665..5b0c18cc6c69f683d12ac6fa47ce1b8c7d1fc038 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -26,7 +26,7 @@ cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator) -cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder op_proto_maker) +cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder op_proto_maker op_info) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry add_op) diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h index 488fa38faf12ee51087643f79295f36bfd33ee22..c7559cefb6415ee141f32e4357459653564cd2ac 100644 --- a/paddle/framework/attribute.h +++ b/paddle/framework/attribute.h @@ -45,6 +45,21 @@ inline AttrType AttrTypeID() { Attribute GetAttrValue(const OpDesc::Attr& attr_desc); +class AttrReader { 
+ public: + explicit AttrReader(const AttributeMap& attrs) : attrs_(attrs) {} + + template <typename T> + inline const T& Get(const std::string& name) const { + PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap", + name); + return boost::get<T>(attrs_.at(name)); + } + + private: + const AttributeMap& attrs_; +}; + // check whether a value(attribute) fit a certain limit template <typename T> class GreaterThanChecker { diff --git a/paddle/framework/backward.md b/paddle/framework/backward.md index 0a6d762bc8be5201ac196b4bc6107c06d07a31d7..ac60be572419b62f4beb644ff192d413c35e19bb 100644 --- a/paddle/framework/backward.md +++ b/paddle/framework/backward.md @@ -2,7 +2,7 @@ ## Motivation -In Neural Network, many model is solved by the the backpropagation algorithm(known as BP) at present. Technically it caculates the gradient of the loss function, then distributed back through the networks. Follows the chain rule, so we need a module chains the gradient operators/expressions together with to construct the backward pass. Every forward network needs a backward network to construct the full computation graph, the operator/expression's backward pass will be generated respect to forward pass. +In neural networks, most models are currently trained with the backpropagation algorithm (known as **BP**). Technically, BP calculates the gradient of the loss function, then propagates it back through the network following the chain rule. Hence we need a module that chains the gradient operators/expressions together to construct the backward pass. Every forward network needs a backward network to construct the full computation graph. The operator/expression's backward pass will be generated with respect to the forward pass. ## Implementation @@ -24,9 +24,9 @@ A backward network is built up with several backward operators. 
Backward operato | **Operator::inputs_** | Inputs | Inputs, Outputs, OutputGradients | | **Operator::outputs_** | Outputs | InputGradients | - In most cases, there is a one-to-one correspondence between the forward and backward operators. These correspondences are recorded by a global hash map(`OpInfoMap`). To follow the philosophy of minimum core and make operators pluggable, the registry mechanism is introduced. + In most cases, there is a one-to-one relation between the forward and backward operators. These relations are recorded by a global hash map(`OpInfoMap`). To follow the philosophy of minimum core and to make operators pluggable, the registry mechanism is introduced. -For example, we have got a `mul_op`, and we can register its information and corresponding backward operator by the following macro: +For example, we have `mul_op`, and we can register its information and corresponding backward operator by the following macro: ```cpp REGISTER_OP(mul, MulOp, MulOpMaker, mul_grad, MulOpGrad); @@ -48,7 +48,7 @@ The function `BuildGradOp` will sequentially execute following processes: 1. Get the `type_` of given forward operator, and then get the corresponding backward operator's type by looking up the `OpInfoMap`. -2. Build two maps named `inputs` and `outputs` to temporary storage backward operator's inputs and outputs. Copy forward operator's `inputs_` and `outputs_` to map `inputs`, except these, are not necessary for gradient computing. +2. Build two maps named `inputs` and `outputs` to temporarily store backward operator's inputs and outputs. Copy forward operator's `inputs_` and `outputs_` to map `inputs`, except these, are not necessary for gradient computing. 3. Add forward inputs' gradient variables into map `output`, adding forward outputs' gradient variables into map `input`. 
@@ -56,11 +56,11 @@ The function `BuildGradOp` will sequentially execute following processes: ### Backward Network Building -A backward network is a series of backward operators. The main idea of building a backward network is creating backward operators in the inverted sequence and append them together one by one. There is some corner case need to process specially. +A backward network is a series of backward operators. The main idea of building a backward network is creating backward operators in the inverted sequence and appending them together one by one. There are some corner cases that need special processing. 1. Op - When the input forward network is an Op, return its gradient Operator Immediately. If all of its outputs are in no gradient set, then return a special `NOP`. + When the input forward network is an Op, return its gradient Operator immediately. If all of its outputs are in the no-gradient set, then return a special `NOP`. 2. NetOp @@ -68,33 +68,33 @@ A backward network is a series of backward operators. The main idea of building 3. RnnOp - RnnOp is a nested stepnet operator. Backward module need to recusively call `Backward` for every stepnet. + RnnOp is a nested stepnet operator. The Backward module needs to recursively call `Backward` for every stepnet. 4. Sharing Variables - **sharing variables**. As illustrated in the pictures, two operator's share the same variable name of W@GRAD, which will overwrite their sharing input variable. + As illustrated in Figure 1 and Figure 2, two operators share the same variable name **W@GRAD**, which will overwrite their shared input variable.


-​ pic 1. Sharing variables in operators. +​ Figure 1. Sharing variables in operators.

-​ Sharing variable between operators or same input variable used in multiple operators leads to a duplicate gradient variable. As demo show above, we need to rename gradient name recursively and add a generic add operator to replace the overwrite links. +​ Sharing a variable between operators, or using the same input variable in multiple operators, can lead to duplicate gradient variables. As illustrated in Figure 2, we need to rename the gradient names recursively and add a generic add operator to prevent overwriting.


-​ pic 2. Replace sharing variable's gradient with `Add` operator. +​ Figure 2. Replace sharing variable's gradient with `Add` operator.

-​ Because our framework finds variables accord to their names, we need to rename the output links. We add a suffix of number to represent its position in clockwise. +​ Because the framework finds variables according to their names, we need to rename the output links. We add an integer suffix to represent its position in the clockwise direction. -5. Part of Gradient is Zero. +5. Part of the Gradient is Zero. - In the whole graph, there is some case of that one operator's gradient is not needed, but its input's gradient is a dependency link of other operator, we need to fill a same shape gradient matrix in the position. In our implement, we insert a special `fillZeroLike` operator. + In the whole graph, there is some case of that one operator's gradient is not needed, but its input's gradient is a dependency link of other operator, we need to fill a same shape gradient matrix in the position. In our implementation, we insert a special `fillZeroLike` operator. Follow these rules above, then collect the sub graph `OutputGradients`/`InputGradients` as the NetOp's and return it. diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index e00c6e8d904508ec9985537fc703c7c61a14e0de..b8fdf69683e645d991cf8dc2297b486680445a00 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -174,4 +174,4 @@ TEST(OpRegistry, CustomChecker) { op->Run(scope, dev_ctx); int test_attr = op->Attr("test_attr"); ASSERT_EQ(test_attr, 4); -} \ No newline at end of file +} diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index fcbfc3e4377edd0ea55c8d4328c325fa18663001..a3f28339aa64c6bde3fcefdae8b0973a5bbdd585 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -14,7 +14,6 @@ limitations under the License. 
*/ #include "paddle/framework/operator.h" #include -#include "paddle/framework/op_registry.h" namespace paddle { namespace framework { @@ -33,6 +32,24 @@ ExecutionContext::GetEigenDevice() const { } #endif +const Tensor* GetTensorFromVar(const Variable* var) { + if (var->IsType<LoDTensor>()) { + return &var->Get<LoDTensor>(); + } + PADDLE_ENFORCE(var->IsType<Tensor>(), + "The Input must be LoDTensor or Tensor."); + return &var->Get<Tensor>(); +} + +Tensor* GetTensorFromVar(Variable* var) { + if (var->IsType<LoDTensor>()) { + return var->GetMutable<LoDTensor>(); + } + PADDLE_ENFORCE(var->IsType<Tensor>(), + "The Input must be LoDTensor or Tensor."); + return var->GetMutable<Tensor>(); +} + std::string OperatorBase::Input(const std::string& name) const { auto& ins = Inputs(name); PADDLE_ENFORCE_LE(ins.size(), 1UL, diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 2d6d5510ef6dc83f1a016be6ff123f0b9bcaf230..77c7c855c0ffed5032e639237b01037a990652c4 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -24,6 +24,7 @@ limitations under the License. */ #include "paddle/framework/framework.pb.h" #include "paddle/framework/lod_tensor.h" #include "paddle/framework/scope.h" +#include "paddle/framework/shape_inference.h" #include "paddle/framework/tensor.h" #include "paddle/platform/device_context.h" #include "paddle/platform/place.h" @@ -56,6 +57,9 @@ class OperatorBase; class InferShapeContext; class ExecutionContext; +extern const Tensor* GetTensorFromVar(const Variable* var); +extern Tensor* GetTensorFromVar(Variable* var); + /** * OperatorBase has the basic element that Net will call to do computation. * Only CreateOperator from OpRegistry will new Operator directly. 
User @@ -262,15 +266,6 @@ class InferShapeContext { return res; } - const Tensor* GetTensorFromVar(const Variable* var) const { - if (var->IsType()) { - return &var->Get(); - } - PADDLE_ENFORCE(var->IsType(), - "The Input(%s) must be LoDTensor or Tensor."); - return &var->Get(); - } - void ShareLoD(const std::string& in, const std::string& out, size_t i = 0, size_t j = 0) const { PADDLE_ENFORCE_LT(i, InputSize(in)); @@ -340,6 +335,78 @@ class ExecutionContext : public InferShapeContext { const platform::DeviceContext& device_context_; }; +class RuntimeInferShapeContext : public InferShapeContextBase { + public: + RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope) + : op_(op), scope_(scope) {} + + bool HasInput(const std::string& name) const { + auto ipt = op_.Input(name); + auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); + return var != nullptr; + } + + bool HasOutput(const std::string& name) const { + auto ipt = op_.Output(name); + auto* var = ipt == kEmptyVarName ? 
nullptr : scope_.FindVar(ipt); + return var != nullptr; + } + + DDim GetInputDim(const std::string& name) const { + return GetDim(op_.Input(name)); + } + + void SetInputDim(const std::string& name, const DDim& dim) { + SetDim(op_.Input(name), dim); + } + + DDim GetOutputDim(const std::string& name) const { + return GetDim(op_.Output(name)); + } + + void SetOutputDim(const std::string& name, const DDim& dim) { + SetDim(op_.Output(name), dim); + } + + AttrReader Attrs() const { return AttrReader(op_.Attrs()); } + + const std::vector& Inputs(const std::string& name) const { + return op_.Inputs(name); + } + + const std::vector& Outputs(const std::string& name) const { + return op_.Outputs(name); + } + + private: + template + Tensor* GetTensor(const std::string& name) const { + Tensor* t = nullptr; + auto* var = scope_.FindVar(name); + if (!var->IsType() && !var->IsType()) { + if (Allocate) { + t = var->GetMutable(); + } else { + PADDLE_THROW("Variable(%s) should be tensor", name); + } + } else { + t = GetTensorFromVar(scope_.FindVar(name)); + } + return t; + } + + DDim GetDim(const std::string& name) const { + return GetTensor(name)->dims(); + } + + void SetDim(const std::string& name, const DDim& dim) { + GetTensor(name)->Resize(dim); + } + + const OperatorBase& op_; + const Scope& scope_; +}; + class OpKernel { public: /** @@ -383,8 +450,10 @@ class OperatorWithKernel : public OperatorBase { const VariableNameMap& outputs, const AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) {} + // runtime infershape void InferShape(const Scope& scope) const override { - InferShape(InferShapeContext(*this, scope)); + auto c = RuntimeInferShapeContext(*this, scope); + InferShape(&c); } void Run(const Scope& scope, @@ -406,7 +475,7 @@ class OperatorWithKernel : public OperatorBase { } protected: - virtual void InferShape(const InferShapeContext& ctx) const = 0; + virtual void InferShape(InferShapeContextBase* ctx) const = 0; }; } // namespace framework diff --git 
a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 0beab0fac5b94c78121261d2661a6f969289afc4..8b4bb01a7bb80eaccee40f14fa82617505b1e2e5 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/framework/operator.h" #include "gtest/gtest.h" +#include "paddle/framework/op_info.h" #include "paddle/framework/op_registry.h" namespace paddle { @@ -114,7 +115,7 @@ class OpWithKernelTest : public OperatorWithKernel { using OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext& ctx) const override {} + void InferShape(framework::InferShapeContextBase* ctx) const override {} }; template diff --git a/paddle/framework/shape_inference.h b/paddle/framework/shape_inference.h new file mode 100644 index 0000000000000000000000000000000000000000..b07fc788124413f728c713027609d9d2d1c39538 --- /dev/null +++ b/paddle/framework/shape_inference.h @@ -0,0 +1,82 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "paddle/framework/ddim.h" + +namespace paddle { +namespace framework { + +class InferShapeContextBase { + public: + virtual ~InferShapeContextBase() {} + virtual bool HasInput(const std::string &name) const = 0; + virtual bool HasOutput(const std::string &name) const = 0; + virtual framework::DDim GetInputDim(const std::string &name) const = 0; + std::vector GetInputsDim(const std::string &name) const { + const std::vector &names = Inputs(name); + return GetDims(names); + } + virtual void SetInputDim(const std::string &name, + const framework::DDim &dim) = 0; + void SetInputsDim(const std::string &name, + const std::vector &dims) { + auto &names = Inputs(name); + SetDims(names, dims); + } + virtual framework::DDim GetOutputDim(const std::string &name) const = 0; + std::vector GetOutputsDim(const std::string &name) const { + const std::vector &names = Outputs(name); + return GetDims(names); + } + virtual void SetOutputDim(const std::string &name, const DDim &dim) = 0; + void SetOutputsDim(const std::string &name, + const std::vector &dims) { + auto &names = Outputs(name); + SetDims(names, dims); + } + virtual AttrReader Attrs() const = 0; + virtual const std::vector &Inputs( + const std::string &name) const = 0; + virtual const std::vector &Outputs( + const std::string &name) const = 0; + // TODO(qiao) implement this function + void ShareLoD(const std::string &in, const std::string &out, size_t i = 0, + size_t j = 0) const {} + + protected: + virtual framework::DDim GetDim(const std::string &name) const = 0; + virtual void SetDim(const std::string &name, const framework::DDim &dim) = 0; + std::vector GetDims( + const std::vector &names) const { + std::vector ret; + ret.reserve(names.size()); + std::transform( + names.begin(), names.end(), std::back_inserter(ret), + [this](const std::string &name) { return this->GetDim(name); }); + return ret; + } + void SetDims(const std::vector &names, + const std::vector &dims) { + size_t length = 
names.size(); + PADDLE_ENFORCE_EQ(length, dims.size()); + for (size_t i = 0; i < length; ++i) { + SetDim(names[i], dims[i]); + } + } +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/variable.md b/paddle/framework/variable.md index f44d5ea46e7ce98dd443d684ad42308496bc4179..442ef6b718b227d79ca73031efcbb55817558252 100644 --- a/paddle/framework/variable.md +++ b/paddle/framework/variable.md @@ -7,7 +7,7 @@ Variable is also known as *blob* in MxNet and Caffe2. It is the input and outpu For the flexibility of a DL system, a variable should be able to contain any typed value -- a tensor in most cases, but could also be some integer IDs or a scope of other variables in the case of RNN. -To use the minimum amount of memory, we'd like that a variable to allocate memory when it has to, or, lazy memory allocation. Let's take the following example: +To use the minimum amount of memory, we would like that a variable allocates memory only when it has to, or, lazy memory allocation. Let's take the following example: ```cpp Variable vr, v1, v2; @@ -38,7 +38,7 @@ This syntax for lazy memory allocation when we call `Randomize` and `Mult`, thos To make memory allocation lazy, we cannot assume that we know the type held by a variable at definition time. In other words, `class Variable` cannot be a template `template class Variable`. -Because we don't know the type `T`, we cannot save a `T*` as `Variable's` data member. Instead, we save an interface object `Placeholder`, who can return the pointer to the saved object via `Placeholder::Ptr()` as `void*`. +Because we don't know the type `T`, we cannot save a `T*` as `Variable's` data member. Instead, we save an interface object `Placeholder`, which can return the pointer to the saved object via `Placeholder::Ptr()` as `void*`. But anyway, Variable needs to know `T` so could it `delete(ptr)` and so could `Variable::Get` checks the expected type and the saved object's type. 
@@ -49,4 +49,4 @@ Because `PlaceholderImpl` knows `T`, it can save and return `typeid(T)` for the ## Conclusion -The technique type hiding utilizes C++ class templates, interface and derivation, and C++ RTTI (typeid). This combination saves us from definition something like `caffe2::TypeMata`, which takes hundreds of lines of C++ code. +The technique type hiding utilizes C++ class templates, interface and derivation, and C++ RTTI (typeid). This combination saves us from defining something like `caffe2::TypeMeta`, which takes hundreds of lines of C++ code. diff --git a/paddle/operators/accuracy_op.cc b/paddle/operators/accuracy_op.cc index 70e4f9da1221ab300e2b507a3da2f7c5da93f2e4..82010bfb53e58a0836c99c353590f4e32e25ac4a 100644 --- a/paddle/operators/accuracy_op.cc +++ b/paddle/operators/accuracy_op.cc @@ -22,25 +22,23 @@ class AccuracyOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL( - ctx.InputVar("Inference"), - "Input(Inference) of AccuracyOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), - "Input(Label) of AccuracyOp should not be null."); - PADDLE_ENFORCE_NOT_NULL( - ctx.OutputVar("Accuracy"), - "Output(Accuracy) of AccuracyOp should not be null."); + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Inference"), + "Input(Inference) of AccuracyOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Label"), + "Input(Label) of AccuracyOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Accuracy"), + "Output(Accuracy) of AccuracyOp should not be null."); - auto *inference = ctx.Input("Inference"); - auto *label = ctx.Input("Label"); + auto inference_dim = ctx->GetInputDim("Inference"); + auto label_dim = ctx->GetInputDim("Label"); - PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label must be a vector"); - 
PADDLE_ENFORCE_EQ(inference->dims()[0], label->dims()[0], + PADDLE_ENFORCE_EQ(label_dim.size(), 1, "label must be a vector"); + PADDLE_ENFORCE_EQ(inference_dim[0], label_dim[0], "inference size must be the same as label size"); - ctx.Output("Accuracy")->Resize({1}); - ctx.ShareLoD("Inference", /*->*/ "Accuracy"); + ctx->SetOutputDim("Accuracy", {1}); + ctx->ShareLoD("Inference", /*->*/ "Accuracy"); } }; diff --git a/paddle/operators/activation_op.cc b/paddle/operators/activation_op.cc index 06654702bc42cc7cf4917b00693334b1d36ce371..f77e1c572e33533ac672e3d476a7e6dad122031f 100644 --- a/paddle/operators/activation_op.cc +++ b/paddle/operators/activation_op.cc @@ -22,10 +22,9 @@ class ActivationOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - ctx.Output("Y")->Resize( - ctx.Input("X")->dims()); - ctx.ShareLoD("X", /*->*/ "Y"); + void InferShape(framework::InferShapeContextBase *ctx) const override { + ctx->SetOutputDim("Y", ctx->GetInputDim("X")); + ctx->ShareLoD("X", /*->*/ "Y"); } }; @@ -34,9 +33,8 @@ class ActivationOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - ctx.Output(framework::GradVarName("X")) - ->Resize(ctx.Input("Y")->dims()); + void InferShape(framework::InferShapeContextBase *ctx) const override { + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Y")); } }; diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc index ed11d096974341022637676537793645f46738f0..3914d1323083ede6a7ea07e7b4ef76b9e4afd26d 100644 --- a/paddle/operators/add_op.cc +++ b/paddle/operators/add_op.cc @@ -22,25 +22,23 @@ class AddOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const 
framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input(X) of AddOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), - "Input(Y) of AddOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), - "Output(Out) of AddOp should not be null."); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of AddOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) of AddOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of AddOp should not be null."); - PADDLE_ENFORCE_EQ(ctx.Input("X")->dims(), - ctx.Input("Y")->dims(), + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + PADDLE_ENFORCE_EQ(x_dims, y_dims, "Two input of Add Op's dimension must be same."); - ctx.Output("Out")->Resize( - ctx.Input("X")->dims()); + ctx->SetOutputDim("Out", x_dims); } }; class AddOpMaker : public framework::OpProtoAndCheckerMaker { public: - AddOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + AddOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The first input of add op"); AddInput("Y", "The second input of add op"); @@ -58,7 +56,7 @@ class AddOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override {} + void InferShape(framework::InferShapeContextBase* ctx) const override {} }; } // namespace operators diff --git a/paddle/operators/clip_op.cc b/paddle/operators/clip_op.cc index e5a54bc4b226fd24337050fdd84b2de9c49f7949..316d28f174658de0e20ed9512f315da305bbb6d0 100644 --- a/paddle/operators/clip_op.cc +++ b/paddle/operators/clip_op.cc @@ -22,24 +22,24 @@ class ClipOp : public framework::OperatorWithKernel { using 
framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input(X) of ClipOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), - "Output(Out) of ClipOp should not be null."); - auto x_dims = ctx.Input("X")->dims(); - auto max = Attr("max"); - auto min = Attr("min"); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of ClipOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of ClipOp should not be null."); + auto x_dims = ctx->GetInputDim("X"); + auto max = ctx->Attrs().Get("max"); + auto min = ctx->Attrs().Get("min"); PADDLE_ENFORCE_LT(min, max, "max should be greater than min."); - ctx.Output("Out")->Resize(x_dims); - ctx.ShareLoD("X", /*->*/ "Out"); + ctx->SetOutputDim("Out", x_dims); + ctx->ShareLoD("X", /*->*/ "Out"); } }; template class ClipOpMaker : public framework::OpProtoAndCheckerMaker { public: - ClipOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + ClipOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "(Tensor)The input of clip op." 
@@ -61,14 +61,13 @@ class ClipOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
-                            "Input(Out@GRAD) should not be null");
-    auto x_dims = ctx.Input("X")->dims();
-    auto *x_grad = ctx.Output(framework::GradVarName("X"));
-    if (x_grad != nullptr) {
-      x_grad->Resize(x_dims);
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@GRAD) should not be null");
+    auto x_dims = ctx->GetInputDim("X");
+    if (ctx->HasOutput(framework::GradVarName("X"))) {
+      ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
     }
   }
 };
diff --git a/paddle/operators/concat_op.cc b/paddle/operators/concat_op.cc
index 07f847079e834716904dcc038d2097efd268bd3e..01cbfc33efcb4042438fbb398fbcca9457f1334f 100644
--- a/paddle/operators/concat_op.cc
+++ b/paddle/operators/concat_op.cc
@@ -24,31 +24,30 @@ class ConcatOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
-                            "Output(Out) of ConcatOp should not be null.");
+  void InferShape(framework::InferShapeContextBase *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of ConcatOp should not be null.");

-    auto ins = ctx.MultiInput("X");
-    auto *out = ctx.Output("Out");
-    size_t axis = static_cast(ctx.Attr("axis"));
+    auto ins = ctx->GetInputsDim("X");
+    size_t axis = static_cast(ctx->Attrs().Get("axis"));
     size_t n = ins.size();

     PADDLE_ENFORCE_GT(n, 1, "Input tensors count should > 1.");

-    auto out_dims = ins[0]->dims();
+    auto out_dims = ins[0];
     size_t in_zero_dims_size = out_dims.size();
     for (size_t i = 1; i < n; i++) {
       for (size_t j = 0; j < in_zero_dims_size; j++) {
         if (j == axis) {
-          out_dims[axis] += ins[i]->dims()[j];
+          out_dims[axis] += ins[i][j];
           continue;
         }
-        PADDLE_ENFORCE_EQ(out_dims[j], ins[i]->dims()[j],
+        PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j],
                           "Input tensors should have the same "
                           "elements except the specify axis.")
       }
     }
-    out->Resize(out_dims);
+    ctx->SetOutputDim("Out", out_dims);
   }
 };

diff --git a/paddle/operators/cond_op.cc b/paddle/operators/cond_op.cc
index 8262a7a5c8c13c86c5f6c123a14fa89696358c57..1d44782b210bc0c40fd68dba29a16fa6959d6076 100644
--- a/paddle/operators/cond_op.cc
+++ b/paddle/operators/cond_op.cc
@@ -215,7 +215,7 @@ class CondOpProtoAndCheckerMaker : public framework::OpProtoAndCheckerMaker {
     AddComment(R"DOC(
 Sample dependent Cond Operator:
 Given Cond[i] as a 1/0 vector to indicate true/false
-The equation is: 
+The equation is:
 Out[i] = subnet_t[i], if Cond[i] == true
 Out[i] = subnet_t[i], if Cond[i] == false
 )DOC");
diff --git a/paddle/operators/conv2d_op.cc b/paddle/operators/conv2d_op.cc
index c3281db0964de6d7dd6be629fbcc55cabb9fef9d..5cc82944bb6b9a4fc5cd94cf2233ab84fc105fe7 100644
--- a/paddle/operators/conv2d_op.cc
+++ b/paddle/operators/conv2d_op.cc
@@ -27,27 +27,25 @@ class Conv2DOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Input"),
-                            "Input(Input) of Conv2DOp should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Filter"),
-                            "Input(Filter) of Conv2DOp should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Output"),
-                            "Output(Output) of Conv2DOp should not be null.");
-
-    auto in = ctx.Input("Input");
-    auto filter = ctx.Input("Filter");
-    auto out = ctx.Output("Output");
-    std::vector strides = Attr>("strides");
-    std::vector paddings = Attr>("paddings");
-    int groups = Attr("groups");
-    int input_channels = in->dims()[1];
-    int output_channels = filter->dims()[0];
-
-    PADDLE_ENFORCE_EQ(in->dims().size(), 4, "Conv2DOp input should be 4-D.");
-    PADDLE_ENFORCE_EQ(filter->dims().size(), 4,
-                      "Conv2DOp filter should be 4-D.");
-    PADDLE_ENFORCE_EQ(input_channels, filter->dims()[1] * groups,
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Input"),
+                   "Input(Input) of Conv2DOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Filter"),
+                   "Input(Filter) of Conv2DOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Output"),
+                   "Output(Output) of Conv2DOp should not be null.");
+
+    auto in_dims = ctx->GetInputDim("Input");
+    auto filter_dims = ctx->GetInputDim("Filter");
+    std::vector strides = ctx->Attrs().Get>("strides");
+    std::vector paddings = ctx->Attrs().Get>("paddings");
+    int groups = ctx->Attrs().Get("groups");
+    int input_channels = in_dims[1];
+    int output_channels = filter_dims[0];
+
+    PADDLE_ENFORCE_EQ(in_dims.size(), 4, "Conv2DOp input should be 4-D.");
+    PADDLE_ENFORCE_EQ(filter_dims.size(), 4, "Conv2DOp filter should be 4-D.");
+    PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups,
                       "The number of input channels should be equal to filter "
                       "channels * groups.");
     PADDLE_ENFORCE_EQ(
@@ -55,17 +53,17 @@ class Conv2DOp : public framework::OperatorWithKernel {
         "The number of output channels should be divided by groups.");

     auto output_height =
-        outputSize(in->dims()[2], filter->dims()[2], paddings[0], strides[0]);
+        outputSize(in_dims[2], filter_dims[2], paddings[0], strides[0]);
     auto output_width =
-        outputSize(in->dims()[3], filter->dims()[3], paddings[1], strides[1]);
-    out->Resize(
-        {in->dims()[0], filter->dims()[0], output_height, output_width});
+        outputSize(in_dims[3], filter_dims[3], paddings[1], strides[1]);
+    ctx->SetOutputDim(
+        "Output", {in_dims[0], filter_dims[0], output_height, output_width});
   }
 };

 class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  Conv2DOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+  Conv2DOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput(
         "Input",
@@ -108,14 +106,15 @@ class Conv2DOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    auto in = ctx.Input("Input");
-    auto filter = ctx.Input("Filter");
-    auto d_in = ctx.Output(framework::GradVarName("Input"));
-    auto d_filter =
-        ctx.Output(framework::GradVarName("Filter"));
-    if (d_in) d_in->Resize(in->dims());
-    if (d_filter) d_filter->Resize(filter->dims());
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    auto in_dims = ctx->GetInputDim("Input");
+    auto filter_dims = ctx->GetInputDim("Filter");
+    if (ctx->HasOutput(framework::GradVarName("Input"))) {
+      ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
+    }
+    if (ctx->HasOutput(framework::GradVarName("Filter"))) {
+      ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
+    }
   }
 };

diff --git a/paddle/operators/cos_sim_op.cc b/paddle/operators/cos_sim_op.cc
index b56ee2047b811e212b4bf74bf7fbba753a6bcb11..040546f1a6fe1af6d17a5e363a11d27de88d03c2 100644
--- a/paddle/operators/cos_sim_op.cc
+++ b/paddle/operators/cos_sim_op.cc
@@ -24,22 +24,22 @@ class CosSimOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
     // notnull check
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
-                            "Input(X) of CosSimOp should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"),
-                            "Input(Y) of CosSimOp should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
-                            "Output(Out) of CosSimOp should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("XNorm"),
-                            "Output(XNorm) of CosSimOp should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("YNorm"),
-                            "Output(YNorm) of CosSimOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of CosSimOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Y"),
+                   "Input(Y) of CosSimOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of CosSimOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("XNorm"),
+                   "Output(XNorm) of CosSimOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("YNorm"),
+                   "Output(YNorm) of CosSimOp should not be null.");

     // shape check
-    auto x_dims = ctx.Input("X")->dims();
-    auto y_dims = ctx.Input("Y")->dims();
+    auto x_dims = ctx->GetInputDim("X");
+    auto y_dims = ctx->GetInputDim("Y");

     PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(),
                       "Ranks of Input(X) and Input(Y) must be equal.");
@@ -54,16 +54,16 @@ class CosSimOp : public framework::OperatorWithKernel {
                       " just 1 (which will be broadcasted to match Input(X)).");

     // resize tensor
-    ctx.Output("Out")->Resize({x_dims[0], 1});
-    ctx.Output("XNorm")->Resize({x_dims[0], 1});
-    ctx.Output("YNorm")->Resize({y_dims[0], 1});
-    ctx.ShareLoD("X", /*->*/ "Out");
+    ctx->SetOutputDim("Out", {x_dims[0], 1});
+    ctx->SetOutputDim("XNorm", {x_dims[0], 1});
+    ctx->SetOutputDim("YNorm", {y_dims[0], 1});
+    ctx->ShareLoD("X", /*->*/ "Out");
   }
 };

 class CosSimOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  CosSimOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+  CosSimOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The 1st input of cos_sim op.");
     AddInput("Y", "The 2nd input of cos_sim op.");
@@ -98,27 +98,23 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
     // notnull check
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) must not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("XNorm"),
-                            "Input(XNorm) must not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("YNorm"),
-                            "Input(YNorm) must not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Out"),
-                            "Input(Out) must not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
-                            "Input(Out@GRAD) must not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) must not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("XNorm"), "Input(XNorm) must not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("YNorm"), "Input(YNorm) must not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) must not be null.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@GRAD) must not be null.");

     // shape check
-    auto x_dims = ctx.Input("X")->dims();
-    auto y_dims = ctx.Input("Y")->dims();
-    auto xnorm_dims = ctx.Input("XNorm")->dims();
-    auto ynorm_dims = ctx.Input("YNorm")->dims();
-    auto out_dims = ctx.Input("Out")->dims();
-    auto out_grad_dims =
-        ctx.Input(framework::GradVarName("Out"))->dims();
+    auto x_dims = ctx->GetInputDim("X");
+    auto y_dims = ctx->GetInputDim("Y");
+    auto xnorm_dims = ctx->GetInputDim("XNorm");
+    auto ynorm_dims = ctx->GetInputDim("YNorm");
+    auto out_dims = ctx->GetInputDim("Out");
+    auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));

     PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
                       "Ranks of Input(X) and Input(Y) must be equal.");
@@ -143,10 +139,14 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
                       "Shape of Input(Out@Grad) must be [X.Dim(0), 1].");

     // resize tensor
-    auto *x_grad = ctx.Output(framework::GradVarName("X"));
-    auto *y_grad = ctx.Output(framework::GradVarName("Y"));
-    if (x_grad) x_grad->Resize(x_dims);
-    if (y_grad) y_grad->Resize(y_dims);
+    auto x_grad_name = framework::GradVarName("X");
+    auto y_grad_name = framework::GradVarName("Y");
+    if (ctx->HasOutput(x_grad_name)) {
+      ctx->SetOutputDim(x_grad_name, x_dims);
+    }
+    if (ctx->HasOutput(y_grad_name)) {
+      ctx->SetOutputDim(y_grad_name, y_dims);
+    }
   }
 };

diff --git a/paddle/operators/crop_op.cc b/paddle/operators/crop_op.cc
index 52a1123348b10e39bcfa1ba062c893e5f20ed862..9b2305e90e85a6f39d4c584a3251b25f67e81aca 100644
--- a/paddle/operators/crop_op.cc
+++ b/paddle/operators/crop_op.cc
@@ -25,16 +25,14 @@ class CropOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
-                            "Input(X) of CropOp should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
-                            "Output(Out) of CropOp should not be null.");
-    auto x_dim = ctx.Input("X")->dims();
-    auto *y = ctx.Input("Y");
-    auto *out = ctx.Output("Out");
-    if (y == nullptr) {
-      auto shape = Attr>("shape");
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of CropOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of CropOp should not be null.");
+    auto x_dim = ctx->GetInputDim("X");
+    if (!ctx->HasInput("Y")) {
+      auto shape = ctx->Attrs().Get>("shape");
       PADDLE_ENFORCE_EQ(
           int64_t(shape.size()), x_dim.size(),
           "Shape size should be equal to dimention size of input tensor.");
@@ -42,19 +40,20 @@ class CropOp : public framework::OperatorWithKernel {
       for (size_t i = 0; i < shape.size(); ++i) {
         tensor_shape[i] = static_cast(shape[i]);
       }
-      out->Resize(framework::make_ddim(tensor_shape));
+      ctx->SetOutputDim("Out", framework::make_ddim(tensor_shape));
     } else {
-      PADDLE_ENFORCE_EQ(framework::arity(x_dim), framework::arity(y->dims()),
+      auto y_dim = ctx->GetInputDim("Y");
+      PADDLE_ENFORCE_EQ(framework::arity(x_dim), framework::arity(y_dim),
                         "Tensor rank of both CropOp's "
                         "inputs must be same.");
-      out->Resize(y->dims());
+      ctx->SetOutputDim("Out", y_dim);
     }
   }
 };

 class CropOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  CropOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+  CropOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "The input of pad op. "
@@ -78,12 +77,12 @@ class CropOpMaker : public framework::OpProtoAndCheckerMaker {
 Crop Operator.
 Crop input into output, as specified by offsets and shape.

-There are two ways to set shape: 
+There are two ways to set shape:
 1. referenc input: crop input X as shape as reference input.
-                    The dimension of reference input should 
+                    The dimension of reference input should
                     be as same as input X.
 2. shape list: crop input X by shape described by a list.
-               The size of shape list should be as same as 
+               The size of shape list should be as same as
                dimension size of input X.

 The input should be a k-D tensor(k > 0 and k < 7). As an example:
@@ -94,15 +93,15 @@ Given:
          [0, 3, 4, 0, 0]
          [0, 0, 0, 0, 0]]

-and 
+and

     offsets = [0, 1]

 and
-    
+
     shape = [2, 2]

-then we get 
+then we get

     Out = [[1, 2],
            [3, 4]]
@@ -116,14 +115,14 @@ class CropOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
-                            "Input(Out@GRAD) should not be null");
-    auto x_dims = ctx.Input("X")->dims();
-    auto *x_grad = ctx.Output(framework::GradVarName("X"));
-    if (x_grad != nullptr) {
-      x_grad->Resize(x_dims);
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@GRAD) should not be null");
+    auto x_dims = ctx->GetInputDim("X");
+    auto x_grad_name = framework::GradVarName("X");
+    if (ctx->HasOutput(x_grad_name)) {
+      ctx->SetOutputDim(x_grad_name, x_dims);
     }
   }
 };
diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc
index 2e16201e74c153888594ebe6679fb0036734dad4..26fc9b51c44d21d92851030449e116538f937846 100644
--- a/paddle/operators/cross_entropy_op.cc
+++ b/paddle/operators/cross_entropy_op.cc
@@ -22,33 +22,30 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should be not null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
-                            "Input(Label) should be not null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"),
-                            "Output(Y) should be not null.");
-
-    auto x = ctx.Input("X");
-    auto label = ctx.Input("Label");
-    PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(label->dims().size(), 2,
-                      "Input(Label)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
+    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null.");
+
+    auto x_dims = ctx->GetInputDim("X");
+    auto label_dims = ctx->GetInputDim("Label");
+    PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
+    PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2.");
+    PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
                       "The 1st dimension of Input(X) and Input(Label) should "
                       "be equal.");
-    if (ctx.Attr("softLabel")) {
-      PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
+    if (ctx->Attrs().Get("softLabel")) {
+      PADDLE_ENFORCE_EQ(x_dims[1], label_dims[1],
                         "If Attr(softLabel) == true, the 2nd dimension of "
                         "Input(X) and Input(Label) should be equal.");
     } else {
-      PADDLE_ENFORCE_EQ(label->dims()[1], 1,
+      PADDLE_ENFORCE_EQ(label_dims[1], 1,
                         "If Attr(softLabel) == false, the 2nd dimension of "
                         "Input(Label) should be 1.");
     }

-    ctx.Output("Y")->Resize({x->dims()[0], 1});
-    ctx.ShareLoD("X", /*->*/ "Y");
+    ctx->SetOutputDim("Y", {x_dims[0], 1});
+    ctx->ShareLoD("X", /*->*/ "Y");
   }
 };

@@ -57,50 +54,45 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should be not null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
-                            "Input(Label) should be not null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")),
-                            "Input(Y@GRAD) shoudl be not null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(framework::GradVarName("X")),
-                            "Output(X@GRAD) should be not null.");
-
-    auto x = ctx.Input("X");
-    auto label = ctx.Input("Label");
-    auto dy = ctx.Input(framework::GradVarName("Y"));
-    PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(dy->dims().size(), 2,
-                      "Input(Y@Grad)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(label->dims().size(), 2,
-                      "Input(Label)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
+    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
+                   "Input(Y@GRAD) shoudl be not null.");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
+                   "Output(X@GRAD) should be not null.");
+
+    auto x_dims = ctx->GetInputDim("X");
+    auto label_dims = ctx->GetInputDim("Label");
+    auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
+    PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
+    PADDLE_ENFORCE_EQ(dy_dims.size(), 2, "Input(Y@Grad)'s rank should be 2.");
+    PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2.");
+    PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
                       "The 1st dimension of Input(X) and Input(Label) should "
                       "be equal.");
-    PADDLE_ENFORCE_EQ(x->dims()[0], dy->dims()[0],
+    PADDLE_ENFORCE_EQ(x_dims[0], dy_dims[0],
                       "The 1st dimension of Input(X) and Input(Y@Grad) should "
                       "be equal.");
-    PADDLE_ENFORCE_EQ(dy->dims()[1], 1,
+    PADDLE_ENFORCE_EQ(dy_dims[1], 1,
                       "The 2nd dimension of Input(Y@Grad) should be 1.");
-    if (ctx.Attr("softLabel")) {
-      PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
+    if (ctx->Attrs().Get("softLabel")) {
+      PADDLE_ENFORCE_EQ(x_dims[1], label_dims[1],
                         "When Attr(softLabel) == true, the 2nd dimension of "
                         "Input(X) and Input(Label) should be equal.");
     } else {
-      PADDLE_ENFORCE_EQ(label->dims()[1], 1,
+      PADDLE_ENFORCE_EQ(label_dims[1], 1,
                         "When Attr(softLabel) == false, the 2nd dimension of "
                         "Input(Label) should be 1.");
     }
-
-    auto dx = ctx.Output(framework::GradVarName("X"));
-    dx->Resize(x->dims());
+    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
   }
 };

 class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  CrossEntropyOpMaker(framework::OpProto *proto,
-                      framework::OpAttrChecker *op_checker)
+  CrossEntropyOpMaker(framework::OpProto* proto,
+                      framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "(Tensor, default Tensor), a 2-D tensor with shape N x D, "
diff --git a/paddle/operators/dropout_op.cc b/paddle/operators/dropout_op.cc
index 2130eda6a42c893d8ec251a7022a0bfa44433bb7..a669b5cf00f1f4ad351486e2977bf8a76aa5bf62 100644
--- a/paddle/operators/dropout_op.cc
+++ b/paddle/operators/dropout_op.cc
@@ -24,25 +24,25 @@ class DropoutOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
-    PADDLE_ENFORCE_GE(ctx.Attr("dropout_prob"), 0);
-    PADDLE_ENFORCE_LE(ctx.Attr("dropout_prob"), 1);
-
-    auto dims = ctx.Input("X")->dims();
-    ctx.Output("Out")->Resize(dims);
-    if (ctx.Attr("is_training")) {
-      ctx.Output("Mask")->Resize(dims);
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
+    PADDLE_ENFORCE_GE(ctx->Attrs().Get("dropout_prob"), 0);
+    PADDLE_ENFORCE_LE(ctx->Attrs().Get("dropout_prob"), 1);
+
+    auto x_dims = ctx->GetInputDim("X");
+    ctx->SetOutputDim("Out", x_dims);
+    if (ctx->Attrs().Get("is_training") == 1) {
+      ctx->SetOutputDim("Mask", x_dims);
     }
-    ctx.ShareLoD("X", /*->*/ "Out");
+    ctx->ShareLoD("X", /*->*/ "Out");
   }
 };

 template
 class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  DropoutOpMaker(framework::OpProto *proto,
-                 framework::OpAttrChecker *op_checker)
+  DropoutOpMaker(framework::OpProto* proto,
+                 framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddAttr("dropout_prob", "Probability of setting units to zero.")
         .SetDefault(.5f);
@@ -70,27 +70,26 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE(ctx.Attr("is_training"),
-                   "GradOp is only callable when is_training is true");
-
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Mask"), "Mask must not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
-                            "Input(Out@GRAD) must not be null.");
-
-    PADDLE_ENFORCE_GE(ctx.Attr("dropout_prob"), 0);
-    PADDLE_ENFORCE_LE(ctx.Attr("dropout_prob"), 1);
-    auto x_dims = ctx.Input("X")->dims();
-    auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims();
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE_EQ(ctx->Attrs().Get("is_training"), 1,
+                      "GradOp is only callable when is_training is true");
+
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Mask"), "Mask must not be null.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@GRAD) must not be null.");
+
+    PADDLE_ENFORCE_GE(ctx->Attrs().Get("dropout_prob"), 0);
+    PADDLE_ENFORCE_LE(ctx->Attrs().Get("dropout_prob"), 1);
+    auto x_dims = ctx->GetInputDim("X");
+    auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
     PADDLE_ENFORCE_EQ(x_dims, out_dims,
                       "Dimensions of Input(X) and Out@Grad must be the same.");
-    auto mask_dims = ctx.Input("Mask")->dims();
+    auto mask_dims = ctx->GetInputDim("Mask");
     PADDLE_ENFORCE_EQ(x_dims, mask_dims,
                       "Dimensions of Input(X) and Mask must be the same.");

-    auto *x_grad = ctx.Output(framework::GradVarName("X"));
-    x_grad->Resize(x_dims);
+    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
   }
 };

diff --git a/paddle/operators/elementwise_op.h b/paddle/operators/elementwise_op.h
index f224722c1bec6716e68de9da2509250f7d4b37ae..c4777a00d6781ea751123b2efffb6df8e29630b0 100644
--- a/paddle/operators/elementwise_op.h
+++ b/paddle/operators/elementwise_op.h
@@ -202,21 +202,20 @@ class ElementwiseOp : public framework::OperatorWithKernel {

  protected:
   using Tensor = framework::Tensor;
-  void InferShape(const framework::InferShapeContext& ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
-                            "Input(X) of elementwise op should not be null");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"),
-                            "Input(Y) of elementwise op should not be null");
-    PADDLE_ENFORCE_NOT_NULL(
-        ctx.OutputVar("Out"),
-        "Output(Out) of elementwise op should not be null.");
-
-    auto x_dim = ctx.Input("X")->dims();
-    auto y_dim = ctx.Input("Y")->dims();
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of elementwise op should not be null");
+    PADDLE_ENFORCE(ctx->HasInput("Y"),
+                   "Input(Y) of elementwise op should not be null");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of elementwise op should not be null.");
+
+    auto x_dim = ctx->GetInputDim("X");
+    auto y_dim = ctx->GetInputDim("Y");
     PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
                       "Rank of first input must >= rank of second input.")
-    ctx.Output("Out")->Resize(x_dim);
-    ctx.ShareLoD("X", /*->*/ "Out");
+    ctx->SetOutputDim("Out", x_dim);
+    ctx->ShareLoD("X", /*->*/ "Out");
   }
 };

@@ -234,7 +233,7 @@ must be small or equal to X's dimensions.
 )DOC");
     AddAttr("axis",
                  R"DOC(
-When the shape(Y) does not equal the shape(X),Y will be broadcasted 
+When the shape(Y) does not equal the shape(X),Y will be broadcasted
 to match the shape of X and axis should be dimension index Y in X
         )DOC")
         .SetDefault(-1)
@@ -244,7 +243,7 @@ to match the shape of X and axis should be dimension index Y in X
     comment_ = R"DOC(
 Limited elementwise {name} operator.The equation is: Out = {equation}.
 1. The shape of Y should be same with X or
-2. Y's shape is a subset of X. 
+2. Y's shape is a subset of X.
    Y will be broadcasted to match the shape of X and axis should be dimension index Y in X.

   example:
@@ -284,27 +283,26 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
   using Tensor = framework::Tensor;

  protected:
-  void InferShape(const framework::InferShapeContext& ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
-                            "Input(Out@GRAD) should not be null");
-
-    auto x_dims = ctx.Input("X")->dims();
-    auto y_dims = ctx.Input("Y")->dims();
-    auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims();
-    auto* x_grad = ctx.Output(framework::GradVarName("X"));
-    auto* y_grad = ctx.Output(framework::GradVarName("Y"));
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@GRAD) should not be null");
+
+    auto x_dims = ctx->GetInputDim("X");
+    auto y_dims = ctx->GetInputDim("Y");
+    auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));

     PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
                       "Rank of first input must >= rank of second input.")

-    if (x_grad) {
-      x_grad->Resize(x_dims);
+    auto x_grad_name = framework::GradVarName("X");
+    auto y_grad_name = framework::GradVarName("Y");
+    if (ctx->HasOutput(x_grad_name)) {
+      ctx->SetOutputDim(x_grad_name, x_dims);
     }
-
-    if (y_grad) {
-      y_grad->Resize(y_dims);
+    if (ctx->HasOutput(y_grad_name)) {
+      ctx->SetOutputDim(y_grad_name, y_dims);
     }
   }
 };
diff --git a/paddle/operators/fill_zeros_like_op.cc b/paddle/operators/fill_zeros_like_op.cc
index 761a527a5574edc779340ec595dfe1bc1964438a..e164de6584e7350283781019cc74118c2d13646e 100644
--- a/paddle/operators/fill_zeros_like_op.cc
+++ b/paddle/operators/fill_zeros_like_op.cc
@@ -22,15 +22,13 @@ class FillZerosLikeOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
-                            "Input(X) of FillZerosLikeOp should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"),
-                            "Output(Y) of FillZerosLikeOp should not be null.");
-
-    ctx.Output("Y")->Resize(
-        ctx.Input("X")->dims());
-    ctx.ShareLoD("X", /*->*/ "Y");
+  void InferShape(framework::InferShapeContextBase *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of FillZerosLikeOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Y"),
+                   "Output(Y) of FillZerosLikeOp should not be null.");
+    ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
+    ctx->ShareLoD("X", /*->*/ "Y");
   }
 };

diff --git a/paddle/operators/gather_op.cc b/paddle/operators/gather_op.cc
index fecd1ce2147a1e6f2f7928266be74ed7b647c5b9..0e3cd174adee1e50d0a63861286a26d325484efb 100644
--- a/paddle/operators/gather_op.cc
+++ b/paddle/operators/gather_op.cc
@@ -23,19 +23,19 @@ class GatherOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
-                            "Input(X) of GatherOp should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Index"),
-                            "Input(Index) of GatherOp should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
-                            "Output(Out) of GatherOp should not be null.");
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of GatherOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Index"),
+                   "Input(Index) of GatherOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of GatherOp should not be null.");

-    int batch_size = ctx.Input("Index")->dims()[0];
+    int batch_size = ctx->GetInputDim("Index")[0];
     PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0");
-    framework::DDim output_dims(ctx.Input("X")->dims());
+    framework::DDim output_dims(ctx->GetInputDim("X"));
     output_dims[0] = batch_size;
-    ctx.Output("Out")->Resize(output_dims);
+    ctx->SetOutputDim("Out", output_dims);
   }
 };

@@ -44,23 +44,20 @@ class GatherGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    auto X_grad = ctx.Output(framework::GradVarName("X"));
-    auto X = ctx.Input("X");
-
-    X_grad->Resize(X->dims());
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
   }
 };

 class GatherOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  GatherOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+  GatherOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The source input of gather op");
     AddInput("Index", "The index input of gather op");
     AddOutput("Out", "The output of add op");
     AddComment(R"DOC(
-Gather Operator by selecting from the first axis, 
+Gather Operator by selecting from the first axis,

 Out = X[Index]
 )DOC");
diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc
index 5b7cbb5cc7bcb7e43b15363d37d7b8f2cbf0fbdc..05120a6e7bcfdb8641c722731f462c89e4223339 100644
--- a/paddle/operators/gaussian_random_op.cc
+++ b/paddle/operators/gaussian_random_op.cc
@@ -43,13 +43,10 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(const framework::InferShapeContext& ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(
-        ctx.OutputVar("Out"),
-        "Output(Out) of GaussianRandomOp should not be null.");
-
-    auto* tensor = ctx.Output("Out");
-    auto dims = Attr>("dims");
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of GaussianRandomOp should not be null.");
+    auto dims = ctx->Attrs().Get>("dims");
     std::vector temp;
     temp.reserve(dims.size());
     for (auto dim : dims) {
@@ -57,7 +54,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
     }
     PADDLE_ENFORCE(dims.size() > 0UL,
                    "dims can be one int or array. dims must be set.");
-    tensor->Resize(framework::make_ddim(temp));
+    ctx->SetOutputDim("Out", framework::make_ddim(temp));
   }
 };

diff --git a/paddle/operators/lookup_table_op.cc b/paddle/operators/lookup_table_op.cc
index 04ac24662e9cfec6a49cd213cb76bdebc7b730c8..9b1314bfbade8551d98b0fbabb7c2968d7600db5 100644
--- a/paddle/operators/lookup_table_op.cc
+++ b/paddle/operators/lookup_table_op.cc
@@ -22,27 +22,26 @@ class LookupTableOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("W"),
-                            "Input(W) of LookupTableOp should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Ids"),
-                            "Input(Ids) of LookupTableOp should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
-                            "Output(Out) of LookupTableOp should not be null.");
-
-    auto table_t = ctx.Input("W");
-    auto ids_t = ctx.Input("Ids");
-    auto output_t = ctx.Output("Out");
-
-    output_t->Resize({ids_t->dims()[0], table_t->dims()[1]});
-    ctx.ShareLoD("Ids", /*->*/ "Out");
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("W"),
+                   "Input(W) of LookupTableOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Ids"),
+                   "Input(Ids) of LookupTableOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of LookupTableOp should not be null.");
+
+    auto table_dims = ctx->GetInputDim("W");
+    auto ids_dims = ctx->GetInputDim("Ids");
+
+    ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]});
+    ctx->ShareLoD("Ids", /*->*/ "Out");
   }
 };

 class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  LookupTableOpMaker(framework::OpProto *proto,
-                     framework::OpAttrChecker *op_checker)
+  LookupTableOpMaker(framework::OpProto* proto,
+                     framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("W",
              "An input
represents embedding tensors," @@ -66,11 +65,9 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &context) const override { - auto table = context.Input("W"); - auto d_table = - context.Output(framework::GradVarName("W")); - d_table->Resize(table->dims()); + void InferShape(framework::InferShapeContextBase* ctx) const override { + auto table_dims = ctx->GetInputDim("W"); + ctx->SetOutputDim(framework::GradVarName("W"), table_dims); } }; diff --git a/paddle/operators/lstm_unit_op.cc b/paddle/operators/lstm_unit_op.cc index 3600f199770c4b8c9a6561b4c270a91bc8b20c0b..bd75b001cb87d914f6c56ea35dcb5013d68145b2 100644 --- a/paddle/operators/lstm_unit_op.cc +++ b/paddle/operators/lstm_unit_op.cc @@ -22,37 +22,36 @@ class LstmUnitOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input(X) of LSTM should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("C_prev"), - "Input(C_prev) of LSTM should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("C"), - "Output(C) of LSTM should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("H"), - "Output(H) of LSTM should not be null."); - - auto *x = ctx.Input("X"); - auto *c_prev = ctx.Input("C_prev"); - - PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2."); - PADDLE_ENFORCE(x->dims()[0] == c_prev->dims()[0], + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("C_prev"), + "Input(C_prev) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("C"), + "Output(C) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("H"), + "Output(H) of LSTM 
should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + auto c_prev_dims = ctx->GetInputDim("C_prev"); + + PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); + PADDLE_ENFORCE(x_dims[0] == c_prev_dims[0], "Batch size of inputs and states must be equal"); - PADDLE_ENFORCE(x->dims()[1] == c_prev->dims()[1] * 4, + PADDLE_ENFORCE(x_dims[1] == c_prev_dims[1] * 4, "Dimension of FC should equal to prev state * 4"); - int b_size = c_prev->dims()[0]; // batch size - int s_dim = c_prev->dims()[1]; // state dim - ctx.Output("C")->Resize({b_size, s_dim}); - ctx.Output("H")->Resize({b_size, s_dim}); + int b_size = c_prev_dims[0]; // batch size + int s_dim = c_prev_dims[1]; // state dim + ctx->SetOutputDim("C", {b_size, s_dim}); + ctx->SetOutputDim("H", {b_size, s_dim}); } }; template class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker { public: - LstmUnitOpMaker(framework::OpProto *proto, - framework::OpAttrChecker *op_checker) + LstmUnitOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "FC input before the non-linear activation."); AddInput( @@ -63,11 +62,11 @@ class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC(Lstm-Unit Operator -Equation: +Equation: i, f, o, j = split(X) C = C_prev * sigm(f + forget_bias) + sigm(i) * tanh(j) H = C * sigm(o) - + )DOC"); AddAttr("forget_bias", "The forget bias of Lstm Unit.") .SetDefault(0.0); @@ -79,15 +78,14 @@ class LstmUnitGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("C")), - "Input(C@GRAD) should not be null"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("H")), - "Input(H@GRAD) should not be null"); - ctx.Output(framework::GradVarName("X")) - 
->Resize(ctx.Input("X")->dims()); - ctx.Output(framework::GradVarName("C_prev")) - ->Resize(ctx.Input("C_prev")->dims()); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("C")), + "Input(C@GRAD) should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("H")), + "Input(H@GRAD) should not be null"); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + ctx->SetOutputDim(framework::GradVarName("C_prev"), + ctx->GetInputDim("C_prev")); } }; diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc index b04384bda81b93f5db0be3206eee10ad5e854540..d799239d4ed6d230578c77921a1a454b476b63fa 100644 --- a/paddle/operators/mean_op.cc +++ b/paddle/operators/mean_op.cc @@ -22,18 +22,18 @@ class MeanOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input(X) of MeanOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), - "Output(Out) of MeanOp should not be null."); - ctx.Output("Out")->Resize({1}); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of MeanOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of MeanOp should not be null."); + ctx->SetOutputDim("Out", {1}); } }; class MeanOpMaker : public framework::OpProtoAndCheckerMaker { public: - MeanOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + MeanOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The input of mean op"); AddOutput("Out", "The output of mean op").NotInGradient(); @@ -47,9 +47,8 @@ class MeanGradOp : public framework::OperatorWithKernel { using 
framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - ctx.Output(framework::GradVarName("X")) - ->Resize(ctx.Input("X")->dims()); + void InferShape(framework::InferShapeContextBase* ctx) const override { + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); } }; diff --git a/paddle/operators/minus_op.cc b/paddle/operators/minus_op.cc index 29cb85489bd05f6c1e7143d962eac0af26e75825..ce049d4d7bd96a6758d71b381e6e6b4edbcc8b5c 100644 --- a/paddle/operators/minus_op.cc +++ b/paddle/operators/minus_op.cc @@ -26,22 +26,22 @@ class MinusOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input(X) of MinusOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), - "Input(Y) of MinusOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), - "Output(Out) of MinusOp should not be null."); + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of MinusOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), + "Input(Y) of MinusOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of MinusOp should not be null."); - auto *left_tensor = ctx.Input("X"); - auto *right_tensor = ctx.Input("Y"); + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); PADDLE_ENFORCE_EQ( - left_tensor->numel(), right_tensor->numel(), + x_dims, y_dims, "Minus operator must take two tensor with same num of elements"); - ctx.Output("Out")->Resize(left_tensor->dims()); - ctx.ShareLoD("X", /*->*/ "Out"); + ctx->SetOutputDim("Out", x_dims); + ctx->ShareLoD("X", /*->*/ "Out"); } }; diff --git a/paddle/operators/modified_huber_loss_op.cc b/paddle/operators/modified_huber_loss_op.cc 
index 8606c0d1e1bf7a52299528d30af0367d9f93edd2..84212a2b3be1ac3664ebd77c7a0ae4d86abad3a0 100644 --- a/paddle/operators/modified_huber_loss_op.cc +++ b/paddle/operators/modified_huber_loss_op.cc @@ -22,20 +22,19 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext& context) const override { - PADDLE_ENFORCE_NOT_NULL(context.InputVar("X"), "X must be initialized."); - PADDLE_ENFORCE_NOT_NULL(context.InputVar("Y"), "Y must be initialized."); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "X must be initialized."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Y must be initialized."); - auto* x = context.Input("X"); - auto* y = context.Input("Y"); + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); - PADDLE_ENFORCE_EQ(x->dims(), y->dims(), - "The shape of X and Y must be the same."); - PADDLE_ENFORCE_EQ(x->dims().size(), 2, "The tensor rank of X must be 2."); - PADDLE_ENFORCE_EQ(x->dims()[1], 1, "The 2nd dimension of X must be 1."); + PADDLE_ENFORCE_EQ(x_dims, y_dims, "The shape of X and Y must be the same."); + PADDLE_ENFORCE_EQ(x_dims.size(), 2, "The tensor rank of X must be 2."); + PADDLE_ENFORCE_EQ(x_dims[1], 1, "The 2nd dimension of X must be 1."); - context.Output("IntermediateVal")->Resize(x->dims()); - context.Output("Out")->Resize({x->dims()[0], 1}); + ctx->SetOutputDim("IntermediateVal", x_dims); + ctx->SetOutputDim("Out", {x_dims[0], 1}); } }; @@ -75,27 +74,28 @@ class ModifiedHuberLossGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext& context) const override { - auto* x = context.Input("X"); - auto* y = context.Input("Y"); - auto* intermediate_val = context.Input("IntermediateVal"); - auto* out_grad = 
context.Input(framework::GradVarName("Out")); - auto* x_grad = - context.Output(framework::GradVarName("X")); - - PADDLE_ENFORCE_NOT_NULL(x, "X must be initialized."); - PADDLE_ENFORCE_NOT_NULL(y, "Y must be initialized."); - PADDLE_ENFORCE_NOT_NULL(intermediate_val, - "Intermediate value must not be null."); - PADDLE_ENFORCE_NOT_NULL(out_grad, "Input(Out@Grad) must not be null."); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "X must be initialized."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Y must be initialized."); + PADDLE_ENFORCE(ctx->HasInput("IntermediateVal"), + "Intermediate value must not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@Grad) must not be null."); + + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + auto intermediate_dims = ctx->GetInputDim("IntermediateVal"); + auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out")); PADDLE_ENFORCE_EQ( - intermediate_val->dims(), x->dims(), + intermediate_dims, x_dims, "The shape of X and intermediate value must be the same."); - PADDLE_ENFORCE_EQ(out_grad->dims(), x->dims(), + PADDLE_ENFORCE_EQ(out_grad_dims, x_dims, "The shape of Input(Out@Grad) and X must be the same."); - if (x_grad) x_grad->Resize(x->dims()); + if (ctx->HasOutput(framework::GradVarName("X"))) { + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + } } }; diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 7047718a3f1bf7e9598952efa1d9bcb20d5cf5b4..9858c4d9c2195c7bd0e767aaa86a950e0a791443 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -24,27 +24,23 @@ class MulOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input(X) of MulOp should not be null."); - 
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), - "Input(Y) of MulOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), - "Output(Out) of MulOp should not be null."); - - auto x_dims = ctx.Input("X")->dims(); - auto y_dims = ctx.Input("Y")->dims(); - int x_num_col_dims = Attr("x_num_col_dims"); - int y_num_col_dims = Attr("y_num_col_dims"); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of MulOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) of MulOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of MulOp should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + int x_num_col_dims = ctx->Attrs().Get("x_num_col_dims"); + int y_num_col_dims = ctx->Attrs().Get("y_num_col_dims"); PADDLE_ENFORCE(x_dims.size() > x_num_col_dims, - "The rank of input tensor X(%s) should be larger than " - "`mul_op`'s `x_num_col_dims`.", - ctx.op().Input("X")); + "The rank of input tensor X should be larger than " + "`mul_op`'s `x_num_col_dims`."); PADDLE_ENFORCE(y_dims.size() > y_num_col_dims, - "The rank of input tensor Y(%s) should be larger than " - "`mul_op`'s `y_num_col_dims`.", - ctx.op().Input("Y")); + "The rank of input tensor Y should be larger than " + "`mul_op`'s `y_num_col_dims`."); auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims); auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_col_dims); @@ -52,24 +48,23 @@ class MulOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ( x_mat_dims[1], y_mat_dims[0], "First matrix's width must be equal with second matrix's height."); - ctx.Output("Out")->Resize( - {x_mat_dims[0], y_mat_dims[1]}); - ctx.ShareLoD("X", /*->*/ "Out"); + ctx->SetOutputDim("Out", {x_mat_dims[0], y_mat_dims[1]}); + ctx->ShareLoD("X", /*->*/ "Out"); } }; class MulOpMaker : public framework::OpProtoAndCheckerMaker { public: - 
MulOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + MulOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The first input of mul op"); AddInput("Y", "The second input of mul op"); AddOutput("Out", "The output of mul op"); AddAttr( "x_num_col_dims", - R"DOC(mul_op can take tensors with more than two dimensions as input `X`, - in that case, tensors will be reshaped to a matrix. The matrix's first - dimension(column length) will be the product of tensor's last + R"DOC(mul_op can take tensors with more than two dimensions as input `X`, + in that case, tensors will be reshaped to a matrix. The matrix's first + dimension(column length) will be the product of tensor's last `num_col_dims` dimensions, and the matrix's second dimension(row length) will be the product of tensor's first `rank - num_col_dims` dimensions. )DOC") @@ -100,16 +95,14 @@ class MulOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), - "Input(Out@GRAD) should not be null"); - auto x_dims = ctx.Input("X")->dims(); - auto y_dims = ctx.Input("Y")->dims(); - auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims(); - auto *x_grad = ctx.Output(framework::GradVarName("X")); - auto *y_grad = ctx.Output(framework::GradVarName("Y")); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be 
null"); + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); auto x_mat_dims = framework::flatten_to_2d(x_dims, Attr("x_num_col_dims")); @@ -125,8 +118,15 @@ class MulOpGrad : public framework::OperatorWithKernel { "The second dimension of Out@GRAD must equal to the second " "dimension of the second operand."); - if (x_grad) x_grad->Resize(x_dims); - if (y_grad) y_grad->Resize(y_dims); + auto x_grad_name = framework::GradVarName("X"); + auto y_grad_name = framework::GradVarName("Y"); + + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + if (ctx->HasOutput(y_grad_name)) { + ctx->SetOutputDim(y_grad_name, y_dims); + } } }; diff --git a/paddle/operators/multiplex_op.cc b/paddle/operators/multiplex_op.cc index 7b50444d16dc57fd14b918d1159e3e21ecd1f1c4..9896d269ccc86d8fdc3bf6375e44ef5bf3e6b9c7 100644 --- a/paddle/operators/multiplex_op.cc +++ b/paddle/operators/multiplex_op.cc @@ -24,41 +24,38 @@ class MultiplexOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Ids"), - "Input(Ids) shouldn't be null."); - PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(), + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Ids"), "Input(Ids) shouldn't be null."); + PADDLE_ENFORCE(!ctx->Inputs("X").empty(), "MultiInput(X) shouldn't be empty."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), - "Output(Out) shouldn't be null."); - auto ids_dim = ctx.Input("Ids")->dims(); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) shouldn't be null."); + auto ids_dim = ctx->GetInputDim("Ids"); PADDLE_ENFORCE( ids_dim.size() == 2 && ids_dim[1] == 1, "The index tensor must be a vector with size batchSize x 1."); - auto ins = ctx.MultiInput("X"); - auto 
*out = ctx.Output("Out"); - auto num_ins = ins.size(); + auto ins_dims = ctx->GetInputsDim("X"); + auto num_ins = ins_dims.size(); PADDLE_ENFORCE(num_ins > 1, "multiplex operator should have more than " "one candidate input tensors."); - auto in_dim = ins[0]->dims(); + auto in_dim = ins_dims[0]; PADDLE_ENFORCE(in_dim.size() >= 2, "The rank of candidate tensors must be not less than 2."); for (size_t i = 1; i < num_ins; i++) { - auto dim = ins[i]->dims(); + auto dim = ins_dims[i]; PADDLE_ENFORCE(in_dim == dim, "All the candidate tensors must have the same size."); } - out->Resize(in_dim); + ctx->SetOutputDim("Out", in_dim); } }; class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker { public: - MultiplexOpMaker(framework::OpProto *proto, - framework::OpAttrChecker *op_checker) + MultiplexOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("Ids", "The index tensor of multiplex operator."); AddInput("X", "The candidate tensors of multiplex operator.") @@ -88,21 +85,19 @@ class MultiplexGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(), - "Input(X) should not be null."); - PADDLE_ENFORCE(!ctx.MultiOutputVar(framework::GradVarName("X")).empty(), + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(!ctx->Inputs("X").empty(), "Input(X) should not be null."); + PADDLE_ENFORCE(!ctx->Outputs(framework::GradVarName("X")).empty(), "Output(X@Grad) should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), - "Input(Out@GRAD) should not be null."); - auto d_ins = ctx.MultiOutput(framework::GradVarName("X")); - auto ins = ctx.MultiInput("X"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not 
be null."); + std::vector d_ins; + auto ins = ctx->GetInputsDim("X"); // No need to compute gradient for Input(Ids) for (size_t i = 0; i < ins.size(); i++) { - if (d_ins[i]) { - d_ins[i]->Resize(ins[i]->dims()); - } + d_ins.push_back(ins[i]); } + ctx->SetOutputsDim(framework::GradVarName("X"), d_ins); } }; diff --git a/paddle/operators/pad_op.cc b/paddle/operators/pad_op.cc index 375d8a35acc0716259071c31bc332fdf5aabce1c..04ebb14f6ee6c73f48aa2f75811a22f9b8a25006 100644 --- a/paddle/operators/pad_op.cc +++ b/paddle/operators/pad_op.cc @@ -24,14 +24,13 @@ class PadOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input(X) of PadOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), - "Output(Out) of PadOp should not be null."); - - auto x_dim = ctx.Input("X")->dims(); - auto paddings = Attr>("paddings"); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of PadOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of PadOp should not be null."); + + auto x_dim = ctx->GetInputDim("X"); + auto paddings = ctx->Attrs().Get>("paddings"); PADDLE_ENFORCE_EQ(x_dim.size() * 2, int64_t(paddings.size()), "Size of paddings should be equal to 2 * dimension size " "of input tensor."); @@ -39,19 +38,18 @@ class PadOp : public framework::OperatorWithKernel { for (int i = 0; i < x_dim.size(); ++i) { out_dims[i] = x_dim[i] + paddings[i * 2] + paddings[i * 2 + 1]; } - ctx.Output("Out")->Resize( - framework::make_ddim(out_dims)); + ctx->SetOutputDim("Out", framework::make_ddim(out_dims)); if (out_dims[0] == x_dim[0]) { // Only pass LoD when the first dimension is equal between // output and input. 
- ctx.ShareLoD("X", /*->*/ "Out"); + ctx->ShareLoD("X", /*->*/ "Out"); } } }; class PadOpMaker : public framework::OpProtoAndCheckerMaker { public: - PadOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + PadOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The input of pad op. " @@ -68,15 +66,15 @@ Given: X = [[1, 2], [3, 4]] -and +and paddings = [0, 1, 1, 2] and - -pad_value = 0 -then we get +pad_value = 0 + +then we get Out = [[0, 1, 2, 0, 0] [0, 3, 4, 0, 0] @@ -101,14 +99,14 @@ class PadOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), - "Input(Out@GRAD) should not be null"); - auto x_dims = ctx.Input("X")->dims(); - auto *x_g = ctx.Output(framework::GradVarName("X")); - if (x_g != nullptr) { - x_g->Resize(x_dims); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto x_dims = ctx->GetInputDim("X"); + auto x_grad_name = framework::GradVarName("X"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); } } }; diff --git a/paddle/operators/prelu_op.cc b/paddle/operators/prelu_op.cc index 912196c190b5ddbd4e3482a5314e949186b94368..1692464f2833a59243ccc1598422180262a59282 100644 --- a/paddle/operators/prelu_op.cc +++ b/paddle/operators/prelu_op.cc @@ -26,19 +26,14 @@ class PReluOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - void InferShape(const framework::InferShapeContext &ctx) const 
override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); - auto *in = ctx.Input("X"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Alpha"), - "Input(Alpha) should not be null"); - auto *alpha = ctx.Input("Alpha"); - PADDLE_ENFORCE(alpha->numel() == 1, "Size of weight Alpha must be one."); - - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), - "Output(Out) should not be null"); - auto *out = ctx.Output("Out"); - out->Resize(in->dims()); - ctx.ShareLoD("X", /*->*/ "Out"); + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + PADDLE_ENFORCE(ctx->HasInput("Alpha"), "Input(Alpha) should not be null"); + PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == 1, + "Size of weight Alpha must be one."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null"); + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + ctx->ShareLoD("X", /*->*/ "Out"); } }; @@ -68,19 +63,13 @@ class PReluGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), - "Input(Out@GRAD) should not be null"); - auto *dx = ctx.Output(framework::GradVarName("X")); - auto *x = ctx.Input("X"); - - auto *dalpha = - ctx.Output(framework::GradVarName("Alpha")); - auto *alpha = ctx.Input("Alpha"); - - dx->Resize(x->dims()); - dalpha->Resize(alpha->dims()); + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + 
ctx->SetOutputDim(framework::GradVarName("Alpha"), + ctx->GetInputDim("Alpha")); } }; diff --git a/paddle/operators/rank_loss_op.cc b/paddle/operators/rank_loss_op.cc index 39af08c8751c3b95cf5fdef7395186a0176a20a2..1ba22006f27abc963e7f161636a964863513a40c 100644 --- a/paddle/operators/rank_loss_op.cc +++ b/paddle/operators/rank_loss_op.cc @@ -25,22 +25,21 @@ class RankLossOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - void InferShape(const framework::InferShapeContext &ctx) const override { + void InferShape(framework::InferShapeContextBase *ctx) const override { // input check - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), - "Input(Label) shouldn't be null"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Left"), - "Input(Left) shouldn't be null"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Right"), - "Input(Right) shouldn't be null"); - auto label_dims = ctx.Input("Label")->dims(); - auto left_dims = ctx.Input("Left")->dims(); - auto right_dims = ctx.Input("Right")->dims(); + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) shouldn't be null"); + PADDLE_ENFORCE(ctx->HasInput("Left"), "Input(Left) shouldn't be null"); + PADDLE_ENFORCE(ctx->HasInput("Right"), "Input(Right) shouldn't be null"); + + auto label_dims = ctx->GetInputDim("Label"); + auto left_dims = ctx->GetInputDim("Left"); + auto right_dims = ctx->GetInputDim("Right"); + PADDLE_ENFORCE((label_dims == left_dims) && (left_dims == right_dims), "All inputs must have the same size"); PADDLE_ENFORCE((label_dims.size() == 2) && (label_dims[1] == 1), "All inputs must be row vector with size batch_size x 1."); - ctx.Output("Out")->Resize(label_dims); + ctx->SetOutputDim("Out", label_dims); } }; @@ -91,25 +90,22 @@ class RankLossGradOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - 
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), - "Input(Label) shouldn't be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Left"), - "Input(Left) shouldn't be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Right"), - "Input(Right) shouldn't be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), - "Input(Out@GRAD) shouldn't be null."); - auto dims = ctx.Input("Left")->dims(); - auto *left_grad = - ctx.Output(framework::GradVarName("Left")); - auto *right_grad = - ctx.Output(framework::GradVarName("Right")); - if (left_grad) { - left_grad->Resize(dims); + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput("Left"), "Input(Left) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput("Right"), "Input(Right) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) shouldn't be null."); + auto dims = ctx->GetInputDim("Left"); + auto left_grad_name = framework::GradVarName("Left"); + auto right_grad_name = framework::GradVarName("Right"); + + if (ctx->HasOutput(left_grad_name)) { + ctx->SetOutputDim(left_grad_name, dims); } - if (right_grad) { - right_grad->Resize(dims); + + if (ctx->HasOutput(right_grad_name)) { + ctx->SetOutputDim(right_grad_name, dims); } } }; diff --git a/paddle/operators/reshape_op.cc b/paddle/operators/reshape_op.cc index ddb93007e21e4d1ae4be3650019c8bc6a680252d..a3c3fa2716ad9f6487e3eff2d98b2c76d964ddef 100644 --- a/paddle/operators/reshape_op.cc +++ b/paddle/operators/reshape_op.cc @@ -26,14 +26,14 @@ class ReshapeOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - void InferShape(const framework::InferShapeContext &ctx) const override { + void InferShape(framework::InferShapeContextBase *ctx) const override { // input check - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input(X) of 
ReshapeOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), - "Output(Out) of ReshapeOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of ReshapeOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of ReshapeOp should not be null."); - auto shape = ctx.Attr<std::vector<int>>("shape"); + auto shape = ctx->Attrs().Get<std::vector<int>>("shape"); PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty."); for (auto dim : shape) { PADDLE_ENFORCE(dim > 0, "Each dimension of shape must be positive."); @@ -41,8 +41,8 @@ class ReshapeOp : public framework::OperatorWithKernel { // capacity check int64_t capacity = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()); - auto *in = ctx.Input("X"); - int64_t in_size = framework::product(in->dims()); + auto x_dims = ctx->GetInputDim("X"); + int64_t in_size = framework::product(x_dims); PADDLE_ENFORCE_EQ(capacity, in_size, "The size of Input(X) mismatches with Attr(shape)."); // resize output @@ -50,11 +50,11 @@ class ReshapeOp : public framework::OperatorWithKernel { std::transform(shape.begin(), shape.end(), shape_int64.begin(), [](int a) { return static_cast<int64_t>(a); }); auto out_dims = framework::make_ddim(shape_int64); - ctx.Output("Out")->Resize(out_dims); - if (shape[0] == in->dims()[0]) { + ctx->SetOutputDim("Out", out_dims); + if (shape[0] == x_dims[0]) { // Only pass LoD when the first dimension is equal between // output and input. 
- ctx.ShareLoD("X", /*->*/ "Out"); + ctx->ShareLoD("X", /*->*/ "Out"); } } }; @@ -76,7 +76,7 @@ Given a 2-D tensor X with 2 rows and 2 columns [[1, 2], [3, 4]] -with target shape = [1, 4], the reshape operator will transform +with target shape = [1, 4], the reshape operator will transform the tensor X into a 1-D tensor: [1, 2, 3, 4] @@ -94,13 +94,11 @@ class ReshapeGradOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) shouldn't be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), - "Input(Out@GRAD) shouldn't be null."); - auto dims = ctx.Input("X")->dims(); - auto *d_in = ctx.Output(framework::GradVarName("X")); - d_in->Resize(dims); + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) shouldn't be null."); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); } }; diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc index fc3ad721f210213491617452141dfa8834b067c0..1fcf0959dffd6a68d97dec4e2b5b509d06c0d09c 100644 --- a/paddle/operators/rowwise_add_op.cc +++ b/paddle/operators/rowwise_add_op.cc @@ -24,16 +24,16 @@ class RowwiseAddOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input(X) of RowwiseAddOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("b"), - "Input(b) of RowwiseAddOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), - "Output(Out) of RowwiseAddOp should not be null."); - - auto x_dims = 
ctx.Input("X")->dims(); - auto b_dims = ctx.Input("b")->dims(); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of RowwiseAddOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("b"), + "Input(b) of RowwiseAddOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of RowwiseAddOp should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + auto b_dims = ctx->GetInputDim("b"); PADDLE_ENFORCE_GT( x_dims.size(), b_dims.size(), "The rank of input `X` must be larger than the one of input `b`."); @@ -43,16 +43,17 @@ class RowwiseAddOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ( framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims, "The width of two operands must be same"); - PADDLE_ENFORCE_EQ(ctx.OutputSize("Out"), 1, "The output size must be 1"); - ctx.Output("Out")->Resize(x_dims); - ctx.ShareLoD("X", /*->*/ "Out"); + PADDLE_ENFORCE_EQ(ctx->Outputs("Out").size(), 1, + "The output size must be 1"); + ctx->SetOutputDim("Out", x_dims); + ctx->ShareLoD("X", /*->*/ "Out"); } }; class RowwiseAddOpMaker : public framework::OpProtoAndCheckerMaker { public: - RowwiseAddOpMaker(framework::OpProto *proto, - framework::OpAttrChecker *op_checker) + RowwiseAddOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The left input of row-wise add op, must be matrix"); AddInput("b", "The right input of row-wise add op, must be vector"); @@ -69,25 +70,29 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X should not be null"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("b"), "b should not be null"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), - 
"Input(Out@GRAD) should not be null"); - auto x_dims = ctx.Input("X")->dims(); - auto b_dims = ctx.Input("b")->dims(); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "X should not be null"); + PADDLE_ENFORCE(ctx->HasInput("b"), "b should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto x_dims = ctx->GetInputDim("X"); + auto b_dims = ctx->GetInputDim("b"); PADDLE_ENFORCE_GT( x_dims.size(), b_dims.size(), "The rank of input `X` must be larger than the one of input `b`."); - int num_col_dims = x_dims.size() - b_dims.size(); + int64_t num_col_dims = x_dims.size() - b_dims.size(); PADDLE_ENFORCE_EQ( framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims, "The width of two operands must be same"); - auto *dx = ctx.Output(framework::GradVarName("X")); - auto *db = ctx.Output(framework::GradVarName("b")); - if (dx) dx->Resize(x_dims); - if (db) db->Resize(b_dims); + auto x_grad_name = framework::GradVarName("X"); + auto b_grad_name = framework::GradVarName("b"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + if (ctx->HasOutput(b_grad_name)) { + ctx->SetOutputDim(b_grad_name, b_dims); + } } }; diff --git a/paddle/operators/scale_op.cc b/paddle/operators/scale_op.cc index 1ae77a9722ef1a5548a6c4100c32fdddcee8c5cd..e92501e12834b92875f494de401672344f50e3b5 100644 --- a/paddle/operators/scale_op.cc +++ b/paddle/operators/scale_op.cc @@ -26,16 +26,13 @@ class ScaleOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input(X) of ScaleOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), - "Output(Out) of ScaleOp should not be null."); - - auto *in = ctx.Input("X"); - auto *out = ctx.Output("Out"); 
- out->Resize(in->dims()); - ctx.ShareLoD("X", /*->*/ "Out"); + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of ScaleOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of ScaleOp should not be null."); + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + ctx->ShareLoD("X", /*->*/ "Out"); } }; diff --git a/paddle/operators/scatter_op.cc b/paddle/operators/scatter_op.cc index 3f02081a060281dec533c02b346f0667da28b8c3..3fc4a39ebc5526bfed61ba667c3cdc214cdd056c 100644 --- a/paddle/operators/scatter_op.cc +++ b/paddle/operators/scatter_op.cc @@ -23,29 +23,30 @@ class ScatterOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Ref"), - "Input(Ref) of ScatterOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Index"), - "Input(Index) of ScatterOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Updates"), - "Input(Updates) of ScatterOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), - "Output(Out) of ScatterOp should not be null."); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Ref"), + "Input(Ref) of ScatterOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Index"), + "Input(Index) of ScatterOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Updates"), + "Input(Updates) of ScatterOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of ScatterOp should not be null."); - PADDLE_ENFORCE_EQ(ctx.Input("Index")->dims().size(), 1, + auto updates_dims = ctx->GetInputDim("Updates"); + auto ref_dims = ctx->GetInputDim("Ref"); + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Index").size(), 1, "Update Index should be 1-D."); - 
PADDLE_ENFORCE_EQ(ctx.Input("Ref")->dims().size(), - ctx.Input("Updates")->dims().size(), + PADDLE_ENFORCE_EQ(ref_dims.size(), updates_dims.size(), "Reference and Updates should have the same shape size"); - PADDLE_ENFORCE_EQ(ctx.Input("Updates")->dims()[0], - ctx.Input("Index")->dims()[0], + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Updates")[0], + ctx->GetInputDim("Index")[0], "Updates and Index should have same batch-size."); - framework::DDim data_dim(ctx.Input("Updates")->dims()); - for (int i = 1; i < data_dim.size(); ++i) - PADDLE_ENFORCE_EQ(data_dim[i], ctx.Input("Updates")->dims()[i]); - ctx.Output("Out")->Resize( - ctx.Input("Ref")->dims()); + framework::DDim data_dim(updates_dims); + for (int i = 1; i < data_dim.size(); ++i) { + PADDLE_ENFORCE_EQ(data_dim[i], updates_dims[i]); + } + ctx->SetOutputDim("Out", ref_dims); } }; @@ -54,22 +55,17 @@ class ScatterGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - auto *dUpdates = - ctx.Output(framework::GradVarName("Updates")); - auto *Updates = ctx.Input("Updates"); - auto *dRef = ctx.Output(framework::GradVarName("Ref")); - auto *Ref = ctx.Input("Ref"); - - dRef->Resize(Ref->dims()); - dUpdates->Resize(Updates->dims()); + void InferShape(framework::InferShapeContextBase* ctx) const override { + ctx->SetOutputDim(framework::GradVarName("Updates"), + ctx->GetInputDim("Updates")); + ctx->SetOutputDim(framework::GradVarName("Ref"), ctx->GetInputDim("Ref")); } }; class ScatterOpMaker : public framework::OpProtoAndCheckerMaker { public: - ScatterOpMaker(framework::OpProto *proto, - framework::OpAttrChecker *op_checker) + ScatterOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("Ref", "The source input of scatter op"); AddInput("Index", @@ -77,13 +73,14 @@ class ScatterOpMaker : public 
framework::OpProtoAndCheckerMaker { AddInput("Updates", "The updated value of updates op"); AddOutput("Out", "The output of add op"); AddComment(R"DOC( -Scatter Operator by selecting from the first axis, +Scatter Operator by selecting from the first axis, Out = Ref Out[Index] = Ref[Index] + Updates )DOC"); } }; + } // namespace operators } // namespace paddle diff --git a/paddle/operators/sequence_pool_op.cc b/paddle/operators/sequence_pool_op.cc index 73f9cb879a2ef690909428b3b672b12717a6a02c..17685ea654715f6996e17f6228f266c3aa1ee424 100644 --- a/paddle/operators/sequence_pool_op.cc +++ b/paddle/operators/sequence_pool_op.cc @@ -22,23 +22,12 @@ class SequencePoolOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext& ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input(X) of SequencePoolOp should not be null."); - PADDLE_ENFORCE_NOT_NULL( - ctx.OutputVar("Out"), - "Output(Out) of SequencePoolOp should not be null."); - - auto* x = ctx.Input("X"); - auto dims = x->dims(); - auto lod = x->lod(); - PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now."); - PADDLE_ENFORCE_GE( - dims[0], - /*batch size = */ static_cast(lod[0].size() - 1), - "The first dimension of Input(X) must be large than batch size."); - dims[0] = lod[0].size() - 1; - ctx.Output("Out")->Resize({dims}); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SequenceAvgPoolOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of SequenceAvgPoolOp should not be null."); + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); } }; @@ -61,17 +50,17 @@ class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker { SequencePoolOp pools features of all time-steps of each instance. 
For a mini-batch of 3 variable lengths sentences, containing 2, 3, and 2 time-steps: - + Assume X is a [7,M,N] float LoDTensor, and X->lod()[0] = [0, 2, 5, 7]. - Besides, for the sake of simplicity, we assume M=1 and N=1, + Besides, for the sake of simplicity, we assume M=1 and N=1, and the value of X = [[1, 3], [2, 4, 6], [5, 1]]. Thus, Out is a [3,1,1] float LoDTensor, but Out->lod() is nullptr. - And for different strategy, the value of Out is as follows: + And for different strategy, the value of Out is as follows: - AVERAGE: [2, 4, 3], where 2=(1+3)/2, 4=(2+4+6)/3, 3=(5+1)/2 - SUM: [4, 12, 6], where 4=1+3, 12=2+4+6, 6=5+1 - - SQRT: [2.82, 6.93, 4.24], where 2.82=(1+3)/sqrt(2), + - SQRT: [2.82, 6.93, 4.24], where 2.82=(1+3)/sqrt(2), 6.93=(2+4+6)/sqrt(3), 4.24=(5+1)/sqrt(2) - MAX: [3, 6, 5], where 3=max(1,3), 6=max(2,4,6), 5=max(5,1) - LAST: [3, 6, 1], where 3=last(1,3), 6=last(2,4,6), 1=last(5,1) @@ -85,22 +74,18 @@ class SequencePoolGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext& ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), - "Gradient of Out should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "The input X should not be null."); - auto og_dims = - ctx.Input(framework::GradVarName("Out"))->dims(); - auto x_dims = ctx.Input("X")->dims(); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Gradient of Out should not be null."); + PADDLE_ENFORCE(ctx->HasInput("X"), "The input X should not be null."); + auto og_dims = ctx->GetInputDim(framework::GradVarName("Out")); + auto x_dims = ctx->GetInputDim("X"); PADDLE_ENFORCE_EQ(og_dims.size(), x_dims.size(), "The rank of output grad must equal to Input(X)."); for (int64_t i = 1; i < og_dims.size(); ++i) { PADDLE_ENFORCE_EQ(og_dims[i], x_dims[i], "The 
dimension mismatch."); } - auto* x_grad = - ctx.Output(framework::GradVarName("X")); - x_grad->Resize(x_dims); + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); } }; diff --git a/paddle/operators/sequence_pool_op.h b/paddle/operators/sequence_pool_op.h index 231614b4c1cb0eb1901b1720e933aed5cbb25f77..cb80586e88f8d9e31b7b91a54f5e05ac6fa73f0f 100644 --- a/paddle/operators/sequence_pool_op.h +++ b/paddle/operators/sequence_pool_op.h @@ -46,16 +46,27 @@ class SequencePoolKernel : public framework::OpKernel { int strategy = context.Attr("strategy"); auto dims = in->dims(); - auto lod = in->lod()[0]; + auto lod = in->lod(); int64_t w = in->numel() / dims[0]; + // InferShape by lod + PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now."); + PADDLE_ENFORCE_GE( + dims[0], + /*batch size = */ static_cast(lod[0].size() - 1), + "The first dimension of Input(X) must be large than batch size."); + dims[0] = lod[0].size() - 1; + out->Resize({dims}); + + auto lod_level_0 = lod[0]; + out->mutable_data(context.GetPlace()); auto place = context.GetEigenDevice(); - for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { - Tensor in_t = - in->Slice(static_cast(lod[i]), static_cast(lod[i + 1])); + for (int i = 0; i < static_cast(lod_level_0.size()) - 1; ++i) { + Tensor in_t = in->Slice(static_cast(lod_level_0[i]), + static_cast(lod_level_0[i + 1])); Tensor out_t = out->Slice(i, i + 1); - int64_t h = static_cast(lod[i + 1] - lod[i]); + int64_t h = static_cast(lod_level_0[i + 1] - lod_level_0[i]); auto in_e = EigenMatrix::From(in_t, framework::make_ddim({h, w})); auto out_e = EigenVector::Flatten(out_t); diff --git a/paddle/operators/sgd_op.cc b/paddle/operators/sgd_op.cc index b063e2427217f20eb89f7cd1af0354ad0e400feb..3bce95535cf10c0df95b503c6e362b3f0ba2e723 100644 --- a/paddle/operators/sgd_op.cc +++ b/paddle/operators/sgd_op.cc @@ -22,19 +22,18 @@ class SGDOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; 
protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("param"), - "Input(param) of SGDOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("grad"), - "Input(grad) of SGDOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("param_out"), - "Output(param_out) of SGDOp should not be null."); - - PADDLE_ENFORCE_EQ(ctx.Input("param")->dims(), - ctx.Input("grad")->dims(), + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("param"), + "Input(param) of SGDOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("grad"), + "Input(grad) of SGDOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("param_out"), + "Output(param_out) of SGDOp should not be null."); + + auto param_dim = ctx->GetInputDim("param"); + PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("grad"), "Two input of SGD Op's dimension must be same."); - ctx.Output("param_out") - ->Resize(ctx.Input("param")->dims()); + ctx->SetOutputDim("param_out", param_dim); } }; diff --git a/paddle/operators/smooth_l1_loss_op.cc b/paddle/operators/smooth_l1_loss_op.cc index ae6d1c80b300690b070024d6266a1b99bf2ef04f..2d197e3b1b763fa87939623d47728aab3bff7cd1 100644 --- a/paddle/operators/smooth_l1_loss_op.cc +++ b/paddle/operators/smooth_l1_loss_op.cc @@ -22,33 +22,28 @@ class SmoothL1LossOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext& ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X must be initialized."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Y must be initialized."); - - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - PADDLE_ENFORCE_EQ(x->dims(), y->dims(), - "The shape of X and Y must be the same."); - PADDLE_ENFORCE_GE(x->dims().size(), 2, + void InferShape(framework::InferShapeContextBase* ctx) const override { + 
PADDLE_ENFORCE(ctx->HasInput("X"), "X must be initialized."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Y must be initialized."); + + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + PADDLE_ENFORCE_EQ(x_dims, y_dims, "The shape of X and Y must be the same."); + PADDLE_ENFORCE_GE(x_dims.size(), 2, "The tensor rank of X must be at least 2."); - auto* inside_weight = ctx.Input("InsideWeight"); - if (inside_weight) { - auto* outside_weight = ctx.Input("OutsideWeight"); - PADDLE_ENFORCE_NOT_NULL(outside_weight, - "If weights are provided, must specify both " - "inside and outside weights."); - PADDLE_ENFORCE_EQ(inside_weight->dims(), x->dims(), + if (ctx->HasInput("InsideWeight")) { + PADDLE_ENFORCE(ctx->HasInput("OutsideWeight"), + "If weights are provided, must specify both " + "inside and outside weights."); + PADDLE_ENFORCE_EQ(ctx->GetInputDim("InsideWeight"), x_dims, "The shape of InsideWeight must be same as X."); - PADDLE_ENFORCE_EQ(outside_weight->dims(), x->dims(), + PADDLE_ENFORCE_EQ(ctx->GetInputDim("OutsideWeight"), x_dims, "The shape of OutsideWeight must be same as X."); } - auto* diff = ctx.Output("Diff"); - auto* out = ctx.Output("Out"); - diff->Resize(x->dims()); + ctx->SetOutputDim("Diff", x_dims); // loss is a two-rank tensor - out->Resize({x->dims()[0], 1}); + ctx->SetOutputDim("Out", {x_dims[0], 1}); } }; @@ -99,12 +94,9 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext& ctx) const override { - auto in_dims = ctx.Input("X")->dims(); - auto out_dims = - ctx.Input(framework::GradVarName("Out"))->dims(); - auto* x_grad = ctx.Output(framework::GradVarName("X")); - auto* y_grad = ctx.Output(framework::GradVarName("Y")); + void InferShape(framework::InferShapeContextBase* ctx) const override { + auto in_dims = ctx->GetInputDim("X"); + auto out_dims = 
ctx->GetInputDim(framework::GradVarName("Out")); PADDLE_ENFORCE_GE(out_dims.size(), 2, "The tensor rank of Input(Out@Grad) should be 2."); @@ -114,8 +106,14 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(out_dims[1], 1, "The 2nd dimension of Input(Out@Grad) must be 1."); - if (x_grad) x_grad->Resize(in_dims); - if (y_grad) y_grad->Resize(in_dims); + auto x_grad_name = framework::GradVarName("X"); + auto y_grad_name = framework::GradVarName("Y"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, in_dims); + } + if (ctx->HasOutput(y_grad_name)) { + ctx->SetOutputDim(y_grad_name, in_dims); + } } }; diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc index e15cfe485016552971924a40a172e74a90629dce..e353afee3e10247fbd5c7f4282c366cd1bc39552 100644 --- a/paddle/operators/softmax_op.cc +++ b/paddle/operators/softmax_op.cc @@ -22,22 +22,23 @@ class SoftmaxOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input(X) of SoftmaxOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"), - "Output(Y) of SoftmaxOp should not be null."); - - PADDLE_ENFORCE(ctx.Input("X")->dims().size() == 2UL, + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SoftmaxOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Y"), + "Output(Y) of SoftmaxOp should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE(x_dims.size() == 2UL, "The input of softmax op must be a matrix."); - ctx.Output("Y")->Resize(ctx.Input("X")->dims()); + ctx->SetOutputDim("Y", x_dims); } }; class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { public: - SoftmaxOpMaker(framework::OpProto *proto, - framework::OpAttrChecker *op_checker) + 
SoftmaxOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The input tensor of softmax. " @@ -68,16 +69,15 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should be not null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")), - "Input(Y@GRAD) should be not null."); - PADDLE_ENFORCE_EQ(ctx.Input("Y")->dims(), - ctx.Input(framework::GradVarName("Y"))->dims(), + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should be not null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), + "Input(Y@GRAD) should be not null."); + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Y"), + ctx->GetInputDim(framework::GradVarName("Y")), "Input(Y) and its gradients should have a same shape."); - ctx.Output(framework::GradVarName("X")) - ->Resize(ctx.Input("X")->dims()); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); } }; diff --git a/paddle/operators/split_op.cc b/paddle/operators/split_op.cc index a9d35b4fb79ae83379552ae2c2b4d694bd8f86dd..8640d1010ef6ae352a93ee2fd7b771a90c6efa5c 100644 --- a/paddle/operators/split_op.cc +++ b/paddle/operators/split_op.cc @@ -24,40 +24,42 @@ class SplitOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - // infershape - auto *in = ctx.Input("X"); - auto outs = ctx.MultiOutput("Out"); - size_t axis = static_cast(ctx.Attr("axis")); - size_t num = static_cast(ctx.Attr("num")); - std::vector sections = - static_cast>(ctx.Attr>("sections")); - const size_t n = outs.size(); + void 
InferShape(framework::InferShapeContextBase *ctx) const override { + auto in_dims = ctx->GetInputDim("X"); + auto outs_names = ctx->Outputs("Out"); + size_t axis = static_cast(ctx->Attrs().Get("axis")); + size_t num = static_cast(ctx->Attrs().Get("num")); + std::vector sections = static_cast>( + ctx->Attrs().Get>("sections")); + const size_t outs_number = outs_names.size(); + std::vector outs_dims; + outs_dims.reserve(outs_number); if (num > 0) { - int64_t in_axis_dim = in->dims()[axis]; + int64_t in_axis_dim = in_dims[axis]; PADDLE_ENFORCE_EQ(in_axis_dim % num, 0, "tensor split does not result" " in an equal division"); size_t out_axis_dim = in_axis_dim / num; - for (size_t i = 0; i < n; ++i) { - auto dim = in->dims(); + for (size_t i = 0; i < outs_number; ++i) { + auto dim = in_dims; dim[axis] = out_axis_dim; - outs[i]->Resize(dim); + outs_dims.push_back(dim); } } else if (sections.size() > 0) { - PADDLE_ENFORCE_EQ(sections.size(), n, + PADDLE_ENFORCE_EQ(sections.size(), outs_number, "tensor split sections size" "should be equal to output size."); - for (size_t i = 0; i < n; ++i) { - auto dim = in->dims(); + for (size_t i = 0; i < outs_number; ++i) { + auto dim = in_dims; dim[axis] = sections[i]; - outs[i]->Resize(dim); + outs_dims.push_back(dim); } } else { PADDLE_ENFORCE_NOT_NULL(nullptr, "split operator should", " specify indices or sections."); } + ctx->SetOutputsDim("Out", outs_dims); } }; diff --git a/paddle/operators/squared_l2_distance_op.cc b/paddle/operators/squared_l2_distance_op.cc index 33a564b05b1b490c6d23b7d17cef45b7740dfa39..5a0cb596008a98aacf5e7b5ff70307ea1b8508e6 100644 --- a/paddle/operators/squared_l2_distance_op.cc +++ b/paddle/operators/squared_l2_distance_op.cc @@ -22,24 +22,19 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext& ctx) const override { - PADDLE_ENFORCE_NOT_NULL( - ctx.InputVar("X"), - 
"Input(X) of SquaredL2DistanceOp should not be null."); - PADDLE_ENFORCE_NOT_NULL( - ctx.InputVar("Y"), - "Input(Y) of SquaredL2DistanceOp should not be null."); - PADDLE_ENFORCE_NOT_NULL( - ctx.OutputVar("sub_result"), + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SquaredL2DistanceOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), + "Input(Y) of SquaredL2DistanceOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("sub_result"), "Output(sub_result) of SquaredL2DistanceOp should not be null."); - PADDLE_ENFORCE_NOT_NULL( - ctx.OutputVar("Out"), - "Output(Out) of SquaredL2DistanceOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of SquaredL2DistanceOp should not be null."); - auto* x = ctx.Input("X"); - auto x_dims = x->dims(); - auto* y = ctx.Input("Y"); - auto y_dims = y->dims(); + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); PADDLE_ENFORCE_EQ(framework::arity(x_dims), framework::arity(y_dims), "Tensor rank of both SquaredL2DistanceOp's " @@ -47,17 +42,16 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel { int rank = framework::arity(x_dims); PADDLE_ENFORCE_GE(rank, 2, "Tensor rank should be at least equal to 2."); - PADDLE_ENFORCE_EQ(x->numel() / x_dims[0], y->numel() / y_dims[0], + PADDLE_ENFORCE_EQ(product(x_dims) / x_dims[0], product(y_dims) / y_dims[0], "Product of dimensions expcet the first dimension of " "input and target must be equal."); PADDLE_ENFORCE(y_dims[0] == 1 || y_dims[0] == x_dims[0], "First dimension of target must be equal to input " "or to 1."); - ctx.Output("sub_result") - ->Resize({x_dims[0], x->numel() / x_dims[0]}); - ctx.Output("Out")->Resize({x_dims[0], 1}); - ctx.ShareLoD("X", /*->*/ "Out"); + ctx->SetOutputDim("sub_result", {x_dims[0], product(x_dims) / x_dims[0]}); + ctx->SetOutputDim("Out", {x_dims[0], 1}); + ctx->ShareLoD("X", /*->*/ "Out"); } }; @@ 
-92,22 +86,22 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(const framework::InferShapeContext& ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
-                            "Gradient of Out should not be null");
-    auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims();
-    auto x_dims = ctx.Input("X")->dims();
-    auto y_dims = ctx.Input("Y")->dims();
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Gradient of Out should not be null");
+    auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
+    auto x_dims = ctx->GetInputDim("X");
+    auto y_dims = ctx->GetInputDim("Y");
     PADDLE_ENFORCE_EQ(out_dims[0], x_dims[0],
                       "First dimension of output gradient and "
                       "input value must be equal.");
     PADDLE_ENFORCE_EQ(out_dims[1], 1,
                       "Second dimension of output gradient "
                       "must be 1.");
-    auto* x_grad = ctx.Output(framework::GradVarName("X"));
-    auto* y_grad = ctx.Output(framework::GradVarName("Y"));
-    if (x_grad) x_grad->Resize(x_dims);
-    if (y_grad) y_grad->Resize(y_dims);
+    auto x_grad_name = framework::GradVarName("X");
+    auto y_grad_name = framework::GradVarName("Y");
+    if (ctx->HasOutput(x_grad_name)) ctx->SetOutputDim(x_grad_name, x_dims);
+    if (ctx->HasOutput(y_grad_name)) ctx->SetOutputDim(y_grad_name, y_dims);
   }
 };

diff --git a/paddle/operators/sum_op.cc b/paddle/operators/sum_op.cc
index 437fc262f359525045a4d772ee2c204ef571caa7..8f62a9f4db8d39edc11949df513aebf4fa257d45 100644
--- a/paddle/operators/sum_op.cc
+++ b/paddle/operators/sum_op.cc
@@ -21,31 +21,27 @@ class SumOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(),
-                   "Input(X) of SumOp should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
-                            "Output(Out) of SumOp should not be null.");
-
-    auto ins = ctx.MultiInput("X");
-    auto *out = ctx.Output("Out");
-    int N = ins.size();
-
-    auto in_dim = ins[0]->dims();
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    auto x_dims = ctx->GetInputsDim("X");
+    PADDLE_ENFORCE(!x_dims.empty(), "Input(X) of SumOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of SumOp should not be null.");

+    auto in_dim = x_dims[0];
+    size_t N = x_dims.size();
     PADDLE_ENFORCE_GT(N, 1, "Input tensors count should > 1.");
-    for (int i = 1; i < N; i++) {
-      auto dim = ins[i]->dims();
+    for (size_t i = 1; i < N; i++) {
+      auto dim = x_dims[i];
       PADDLE_ENFORCE(in_dim == dim, "Input tensors must have same shape");
     }
-    out->Resize(in_dim);
-    ctx.ShareLoD("X", /*->*/ "Out");
+    ctx->SetOutputDim("Out", in_dim);
+    ctx->ShareLoD("X", /*->*/ "Out");
   }
 };

 class SumOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+  SumOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "the input tensors of sum operator.").AsDuplicable();
     AddOutput("Out", "the output tensor of sum operator.");
@@ -63,13 +59,16 @@ class SumGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    auto outputs =
-        ctx.MultiOutput(framework::GradVarName("X"));
-    auto dims = ctx.Input(framework::GradVarName("Out"))->dims();
-    for (auto output : outputs) {
-      output->Resize(dims);
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));
+    auto x_grad_names = ctx->Outputs(framework::GradVarName("X"));
+    size_t x_length = x_grad_names.size();
+    std::vector x_grad_dims;
+    x_grad_dims.reserve(x_length);
+    for (size_t i = 0; i < x_length; ++i) {
+      x_grad_dims.push_back(out_grad_dims);
     }
+    ctx->SetOutputsDim(framework::GradVarName("X"), x_grad_dims);
   }
 };

diff --git a/paddle/operators/top_k_op.cc b/paddle/operators/top_k_op.cc
index a6e43964e9825cd1ced9e7c1bc8d691422248fee..5f22bf1df8720b60aba7cd75896d88cd1ad77635 100644
--- a/paddle/operators/top_k_op.cc
+++ b/paddle/operators/top_k_op.cc
@@ -22,26 +22,26 @@ class TopkOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
-                            "Input(X) of TopkOp should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
-                            "Output(Out) of TopkOp should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Indices"),
-                            "Output(Indices) of TopkOp should not be null.");
+  void InferShape(framework::InferShapeContextBase *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of TopkOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of TopkOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Indices"),
+                   "Output(Indices) of TopkOp should not be null.");

-    auto *input = ctx.Input("X");
-    const int k = static_cast(ctx.Attr("k"));
+    auto input_dims = ctx->GetInputDim("X");
+    const int k = static_cast(ctx->Attrs().Get("k"));

     PADDLE_ENFORCE_GE(k, 1, "k must >= 1");
-    PADDLE_ENFORCE_GE(input->dims().size(), 1, "input must have >= 1d shape");
-    PADDLE_ENFORCE_GE(input->dims()[input->dims().size() - 1], k,
+    PADDLE_ENFORCE_GE(input_dims.size(), 1, "input must have >= 1d shape");
+    PADDLE_ENFORCE_GE(input_dims[input_dims.size() - 1], k,
                       "input must have >= k columns");

-    framework::DDim dims = input->dims();
+    framework::DDim dims = input_dims;
     dims[dims.size() - 1] = k;
-    ctx.Output("Out")->Resize(dims);
-    ctx.Output("Indices")->Resize(dims);
+    ctx->SetOutputDim("Out", dims);
+    ctx->SetOutputDim("Indices", dims);
   }
 };

diff --git a/paddle/operators/transpose_op.cc b/paddle/operators/transpose_op.cc
index 017a05326e9b397185d7c3530891884b11784783..0672f9342dac00ecc3f358450a9a203327cbb51f 100644
--- a/paddle/operators/transpose_op.cc
+++ b/paddle/operators/transpose_op.cc
@@ -24,12 +24,11 @@ class TransposeOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
-                            "Output(Out) should not be null");
-    auto x_dims = ctx.Input("X")->dims();
-    std::vector axis = ctx.Attr>("axis");
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null");
+    auto x_dims = ctx->GetInputDim("X");
+    std::vector axis = ctx->Attrs().Get>("axis");
     size_t x_rank = x_dims.size();
     size_t axis_size = axis.size();

@@ -51,14 +50,14 @@ class TransposeOp : public framework::OperatorWithKernel {
     for (size_t i = 0; i < axis_size; i++) {
       out_dims[i] = x_dims[axis[i]];
     }
-    ctx.Output("Out")->Resize(out_dims);
+    ctx->SetOutputDim("Out", out_dims);
   }
 };

 class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  TransposeOpMaker(framework::OpProto *proto,
-                   framework::OpAttrChecker *op_checker)
+  TransposeOpMaker(framework::OpProto* proto,
+                   framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput(
         "X",
@@ -79,7 +78,7 @@ For example:
            [3, 4, 5]])
     >> axis = [1, 0]
     >> output = input.transpose(axis)
-    >> output
+    >> output
     array([[0, 3],
            [1, 4],
            [2, 5]])
@@ -94,14 +93,15 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
-                            "Input(Out@GRAD) should not be null");
-    auto x_dims = ctx.Input("X")->dims();
-    auto *x_grad = ctx.Output(framework::GradVarName("X"));
-
-    if (x_grad) x_grad->Resize(x_dims);
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@GRAD) should not be null");
+    auto x_dims = ctx->GetInputDim("X");
+    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
+    if (ctx->HasOutput(framework::GradVarName("X"))) {
+      ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
+    }
   }
 };

diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc
index 17ea48361bc597ccfeb80884d51900e6567aa057..2771df56086ff261728af84edcdf01cda3e45e9f 100644
--- a/paddle/operators/uniform_random_op.cc
+++ b/paddle/operators/uniform_random_op.cc
@@ -23,18 +23,18 @@ namespace operators {
 template
 class CPUUniformRandomKernel : public framework::OpKernel {
  public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto* tensor = context.Output("Out");
-    T* data = tensor->mutable_data(context.GetPlace());
-    unsigned int seed = static_cast(context.Attr("seed"));
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* tensor = ctx.Output("Out");
+    T* data = tensor->mutable_data(ctx.GetPlace());
+    unsigned int seed = static_cast(ctx.Attr("seed"));
     std::minstd_rand engine;
     if (seed == 0) {
       seed = std::random_device()();
     }
     engine.seed(seed);
     std::uniform_real_distribution dist(
-        static_cast(context.Attr("min")),
-        static_cast(context.Attr("max")));
+        static_cast(ctx.Attr("min")),
+        static_cast(ctx.Attr("max")));
     int64_t size = tensor->numel();
     for (int64_t i = 0; i < size; ++i) {
       data[i] = dist(engine);
@@ -47,21 +47,20 @@ class UniformRandomOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(const framework::InferShapeContext& ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(
-        ctx.OutputVar("Out"),
-        "Output(Out) of UniformRandomOp should not be null.");
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of UniformRandomOp should not be null.");

-    PADDLE_ENFORCE(Attr("min") < Attr("max"),
-                   "uniform_random's min must less then max");
-    auto* tensor = ctx.Output("Out");
+    PADDLE_ENFORCE(
+        ctx->Attrs().Get("min") < ctx->Attrs().Get("max"),
+        "uniform_random's min must less then max");
     auto dims = Attr>("dims");
     std::vector temp;
     temp.reserve(dims.size());
     for (auto dim : dims) {
       temp.push_back(static_cast(dim));
     }
-    tensor->Resize(framework::make_ddim(temp));
+    ctx->SetOutputDim("Out", framework::make_ddim(temp));
   }
 };