提交 9a9d50a6 编写于 作者: Q Qiao Longfei 提交者: GitHub

Refactoring InferShape (#3946)

* init Infershape

* add static InferShape interface

* refactor add-op infershape

* add AttrReader

* add all maker's infershape

* add all InferShape

* add python infer api

* add VarDesc interface

* add python VarDesc and OpDesc interface

* update python code

* use infershape function to do shape inference

* clean code

* do not use pointer

* refine code of op_proto_maker

* add get_dims to VarDesc

* refine the code

* remove the dependency from operator to op registry

* remove OpProtoAndCheckerMaker from operator

* restore complete_add_op

* add shape_infer_impl.h

* code optimization

* remove const return value

* add fake BlockDesc class

* optimize code

* remove infer function in op_info

* move InferShapeContextImpl to operator.h

* optimize the interface of InferShapeContextBase

* add temporary interface of new infershape

* change add_op, clip_op, conv2d_op and activation_op

* change all operators InferShape

* fix SetDim

* update cos_sim_op

* update crop_op

* update lookup_table_op

* allocate tensor when call GetDim in InferShapeContext

* update modified_huber_loss_op

* update rowwise_add_op

* update mean_op

* update sequence_avg_pool_op

* typo

* remove old InferShape interface

* can compile

* fix or unit test

* clean code

* clean code

* remove const before InferShapeContext

* change InferenceContextBase to pointer

* rename RunTime to Runtime, code clean
上级 86351037
...@@ -45,6 +45,21 @@ inline AttrType AttrTypeID() { ...@@ -45,6 +45,21 @@ inline AttrType AttrTypeID() {
Attribute GetAttrValue(const OpDesc::Attr& attr_desc); Attribute GetAttrValue(const OpDesc::Attr& attr_desc);
// Read-only, typed accessor over an AttributeMap. Shape-inference code uses
// this to fetch attribute values without exposing the whole map.
class AttrReader {
 public:
  explicit AttrReader(const AttributeMap& attrs) : attrs_(attrs) {}

  // Returns the attribute `name` as a T; enforces that the attribute exists.
  template <typename T>
  inline const T& Get(const std::string& name) const {
    auto found = attrs_.find(name);
    PADDLE_ENFORCE(found != attrs_.end(), "%s should be in AttributeMap",
                   name);
    return boost::get<T>(found->second);
  }

 private:
  const AttributeMap& attrs_;  // not owned; must outlive this reader
};
// check whether a value(attribute) fit a certain limit // check whether a value(attribute) fit a certain limit
template <typename T> template <typename T>
class GreaterThanChecker { class GreaterThanChecker {
......
...@@ -174,4 +174,4 @@ TEST(OpRegistry, CustomChecker) { ...@@ -174,4 +174,4 @@ TEST(OpRegistry, CustomChecker) {
op->Run(scope, dev_ctx); op->Run(scope, dev_ctx);
int test_attr = op->Attr<int>("test_attr"); int test_attr = op->Attr<int>("test_attr");
ASSERT_EQ(test_attr, 4); ASSERT_EQ(test_attr, 4);
} }
\ No newline at end of file
...@@ -14,7 +14,6 @@ limitations under the License. */ ...@@ -14,7 +14,6 @@ limitations under the License. */
#include "paddle/framework/operator.h" #include "paddle/framework/operator.h"
#include <algorithm> #include <algorithm>
#include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -33,6 +32,24 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const { ...@@ -33,6 +32,24 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
} }
#endif #endif
// Returns a read-only pointer to the tensor held by `var`; the variable must
// contain either a LoDTensor or a plain Tensor.
const Tensor* GetTensorFromVar(const Variable* var) {
  if (!var->IsType<LoDTensor>()) {
    PADDLE_ENFORCE(var->IsType<Tensor>(),
                   "The Input must be LoDTensor or Tensor.");
    return &var->Get<Tensor>();
  }
  return &var->Get<LoDTensor>();
}
// Mutable counterpart of the const overload above: returns a writable pointer
// to the tensor held by `var` (LoDTensor preferred, plain Tensor otherwise).
Tensor* GetTensorFromVar(Variable* var) {
  if (!var->IsType<LoDTensor>()) {
    PADDLE_ENFORCE(var->IsType<Tensor>(),
                   "The Input must be LoDTensor or Tensor.");
    return var->GetMutable<Tensor>();
  }
  return var->GetMutable<LoDTensor>();
}
std::string OperatorBase::Input(const std::string& name) const { std::string OperatorBase::Input(const std::string& name) const {
auto& ins = Inputs(name); auto& ins = Inputs(name);
PADDLE_ENFORCE_LE(ins.size(), 1UL, PADDLE_ENFORCE_LE(ins.size(), 1UL,
......
...@@ -24,6 +24,7 @@ limitations under the License. */ ...@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/framework/framework.pb.h" #include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h" #include "paddle/framework/lod_tensor.h"
#include "paddle/framework/scope.h" #include "paddle/framework/scope.h"
#include "paddle/framework/shape_inference.h"
#include "paddle/framework/tensor.h" #include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h" #include "paddle/platform/device_context.h"
#include "paddle/platform/place.h" #include "paddle/platform/place.h"
...@@ -56,6 +57,9 @@ class OperatorBase; ...@@ -56,6 +57,9 @@ class OperatorBase;
class InferShapeContext; class InferShapeContext;
class ExecutionContext; class ExecutionContext;
extern const Tensor* GetTensorFromVar(const Variable* var);
extern Tensor* GetTensorFromVar(Variable* var);
/** /**
* OperatorBase has the basic element that Net will call to do computation. * OperatorBase has the basic element that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User * Only CreateOperator from OpRegistry will new Operator directly. User
...@@ -262,15 +266,6 @@ class InferShapeContext { ...@@ -262,15 +266,6 @@ class InferShapeContext {
return res; return res;
} }
const Tensor* GetTensorFromVar(const Variable* var) const {
if (var->IsType<LoDTensor>()) {
return &var->Get<LoDTensor>();
}
PADDLE_ENFORCE(var->IsType<Tensor>(),
"The Input(%s) must be LoDTensor or Tensor.");
return &var->Get<Tensor>();
}
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0, void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) const { size_t j = 0) const {
PADDLE_ENFORCE_LT(i, InputSize(in)); PADDLE_ENFORCE_LT(i, InputSize(in));
...@@ -340,6 +335,78 @@ class ExecutionContext : public InferShapeContext { ...@@ -340,6 +335,78 @@ class ExecutionContext : public InferShapeContext {
const platform::DeviceContext& device_context_; const platform::DeviceContext& device_context_;
}; };
class RuntimeInferShapeContext : public InferShapeContextBase {
public:
RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
: op_(op), scope_(scope) {}
bool HasInput(const std::string& name) const {
auto ipt = op_.Input(name);
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr;
}
bool HasOutput(const std::string& name) const {
auto ipt = op_.Output(name);
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr;
}
DDim GetInputDim(const std::string& name) const {
return GetDim(op_.Input(name));
}
void SetInputDim(const std::string& name, const DDim& dim) {
SetDim(op_.Input(name), dim);
}
DDim GetOutputDim(const std::string& name) const {
return GetDim(op_.Output(name));
}
void SetOutputDim(const std::string& name, const DDim& dim) {
SetDim(op_.Output(name), dim);
}
AttrReader Attrs() const { return AttrReader(op_.Attrs()); }
const std::vector<std::string>& Inputs(const std::string& name) const {
return op_.Inputs(name);
}
const std::vector<std::string>& Outputs(const std::string& name) const {
return op_.Outputs(name);
}
private:
template <bool Allocate>
Tensor* GetTensor(const std::string& name) const {
Tensor* t = nullptr;
auto* var = scope_.FindVar(name);
if (!var->IsType<LoDTensor>() && !var->IsType<Tensor>()) {
if (Allocate) {
t = var->GetMutable<LoDTensor>();
} else {
PADDLE_THROW("Variable(%s) should be tensor", name);
}
} else {
t = GetTensorFromVar(scope_.FindVar(name));
}
return t;
}
DDim GetDim(const std::string& name) const {
return GetTensor<false>(name)->dims();
}
void SetDim(const std::string& name, const DDim& dim) {
GetTensor<true>(name)->Resize(dim);
}
const OperatorBase& op_;
const Scope& scope_;
};
class OpKernel { class OpKernel {
public: public:
/** /**
...@@ -383,8 +450,10 @@ class OperatorWithKernel : public OperatorBase { ...@@ -383,8 +450,10 @@ class OperatorWithKernel : public OperatorBase {
const VariableNameMap& outputs, const AttributeMap& attrs) const VariableNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
// runtime infershape
void InferShape(const Scope& scope) const override { void InferShape(const Scope& scope) const override {
InferShape(InferShapeContext(*this, scope)); auto c = RuntimeInferShapeContext(*this, scope);
InferShape(&c);
} }
void Run(const Scope& scope, void Run(const Scope& scope,
...@@ -406,7 +475,7 @@ class OperatorWithKernel : public OperatorBase { ...@@ -406,7 +475,7 @@ class OperatorWithKernel : public OperatorBase {
} }
protected: protected:
virtual void InferShape(const InferShapeContext& ctx) const = 0; virtual void InferShape(InferShapeContextBase* ctx) const = 0;
}; };
} // namespace framework } // namespace framework
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/framework/operator.h" #include "paddle/framework/operator.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
...@@ -114,7 +115,7 @@ class OpWithKernelTest : public OperatorWithKernel { ...@@ -114,7 +115,7 @@ class OpWithKernelTest : public OperatorWithKernel {
using OperatorWithKernel::OperatorWithKernel; using OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext& ctx) const override {} void InferShape(framework::InferShapeContextBase* ctx) const override {}
}; };
template <typename T1, typename T2> template <typename T1, typename T2>
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include <algorithm>
#include <string>
#include <vector>

#include "paddle/framework/ddim.h"
namespace paddle {
namespace framework {
// Abstract interface for shape inference. Concrete subclasses bind it to a
// runtime Scope or to a compile-time block description; operator InferShape
// implementations only ever see this interface.
class InferShapeContextBase {
 public:
  virtual ~InferShapeContextBase() {}

  // Existence checks for single (non-duplicable) inputs/outputs.
  virtual bool HasInput(const std::string &name) const = 0;
  virtual bool HasOutput(const std::string &name) const = 0;

  virtual framework::DDim GetInputDim(const std::string &name) const = 0;

  // Dimensions of every variable bound to a duplicable input slot.
  std::vector<framework::DDim> GetInputsDim(const std::string &name) const {
    return GetDims(Inputs(name));
  }

  virtual void SetInputDim(const std::string &name,
                           const framework::DDim &dim) = 0;

  void SetInputsDim(const std::string &name,
                    const std::vector<framework::DDim> &dims) {
    SetDims(Inputs(name), dims);
  }

  virtual framework::DDim GetOutputDim(const std::string &name) const = 0;

  // Dimensions of every variable bound to a duplicable output slot.
  std::vector<framework::DDim> GetOutputsDim(const std::string &name) const {
    return GetDims(Outputs(name));
  }

  virtual void SetOutputDim(const std::string &name, const DDim &dim) = 0;

  void SetOutputsDim(const std::string &name,
                     const std::vector<framework::DDim> &dims) {
    SetDims(Outputs(name), dims);
  }

  virtual AttrReader Attrs() const = 0;
  virtual const std::vector<std::string> &Inputs(
      const std::string &name) const = 0;
  virtual const std::vector<std::string> &Outputs(
      const std::string &name) const = 0;

  // TODO(qiao) implement this function
  void ShareLoD(const std::string &in, const std::string &out, size_t i = 0,
                size_t j = 0) const {}

 protected:
  virtual framework::DDim GetDim(const std::string &name) const = 0;
  virtual void SetDim(const std::string &name, const framework::DDim &dim) = 0;

  // Maps GetDim over a list of variable names.
  std::vector<framework::DDim> GetDims(
      const std::vector<std::string> &names) const {
    std::vector<framework::DDim> dims;
    dims.reserve(names.size());
    for (const std::string &var_name : names) {
      dims.push_back(GetDim(var_name));
    }
    return dims;
  }

  // Writes dims[i] into names[i]; the two lists must have equal length.
  void SetDims(const std::vector<std::string> &names,
               const std::vector<framework::DDim> &dims) {
    PADDLE_ENFORCE_EQ(names.size(), dims.size());
    for (size_t i = 0; i < names.size(); ++i) {
      SetDim(names[i], dims[i]);
    }
  }
};
} // namespace framework
} // namespace paddle
...@@ -22,25 +22,23 @@ class AccuracyOp : public framework::OperatorWithKernel { ...@@ -22,25 +22,23 @@ class AccuracyOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE(ctx->HasInput("Inference"),
ctx.InputVar("Inference"), "Input(Inference) of AccuracyOp should not be null.");
"Input(Inference) of AccuracyOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Label"),
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), "Input(Label) of AccuracyOp should not be null.");
"Input(Label) of AccuracyOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Accuracy"),
PADDLE_ENFORCE_NOT_NULL( "Output(Accuracy) of AccuracyOp should not be null.");
ctx.OutputVar("Accuracy"),
"Output(Accuracy) of AccuracyOp should not be null.");
auto *inference = ctx.Input<framework::Tensor>("Inference"); auto inference_dim = ctx->GetInputDim("Inference");
auto *label = ctx.Input<framework::Tensor>("Label"); auto label_dim = ctx->GetInputDim("Label");
PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label must be a vector"); PADDLE_ENFORCE_EQ(label_dim.size(), 1, "label must be a vector");
PADDLE_ENFORCE_EQ(inference->dims()[0], label->dims()[0], PADDLE_ENFORCE_EQ(inference_dim[0], label_dim[0],
"inference size must be the same as label size"); "inference size must be the same as label size");
ctx.Output<framework::Tensor>("Accuracy")->Resize({1}); ctx->SetOutputDim("Accuracy", {1});
ctx.ShareLoD("Inference", /*->*/ "Accuracy"); ctx->ShareLoD("Inference", /*->*/ "Accuracy");
} }
}; };
......
...@@ -22,10 +22,9 @@ class ActivationOp : public framework::OperatorWithKernel { ...@@ -22,10 +22,9 @@ class ActivationOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase *ctx) const override {
ctx.Output<framework::Tensor>("Y")->Resize( ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
ctx.Input<framework::Tensor>("X")->dims()); ctx->ShareLoD("X", /*->*/ "Y");
ctx.ShareLoD("X", /*->*/ "Y");
} }
}; };
...@@ -34,9 +33,8 @@ class ActivationOpGrad : public framework::OperatorWithKernel { ...@@ -34,9 +33,8 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase *ctx) const override {
ctx.Output<framework::Tensor>(framework::GradVarName("X")) ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Y"));
->Resize(ctx.Input<framework::Tensor>("Y")->dims());
} }
}; };
......
...@@ -22,25 +22,23 @@ class AddOp : public framework::OperatorWithKernel { ...@@ -22,25 +22,23 @@ class AddOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of AddOp should not be null.");
"Input(X) of AddOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) of AddOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Input(Y) of AddOp should not be null."); "Output(Out) of AddOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
"Output(Out) of AddOp should not be null.");
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("X")->dims(), auto x_dims = ctx->GetInputDim("X");
ctx.Input<Tensor>("Y")->dims(), auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(x_dims, y_dims,
"Two input of Add Op's dimension must be same."); "Two input of Add Op's dimension must be same.");
ctx.Output<framework::Tensor>("Out")->Resize( ctx->SetOutputDim("Out", x_dims);
ctx.Input<Tensor>("X")->dims());
} }
}; };
class AddOpMaker : public framework::OpProtoAndCheckerMaker { class AddOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
AddOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) AddOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of add op"); AddInput("X", "The first input of add op");
AddInput("Y", "The second input of add op"); AddInput("Y", "The second input of add op");
...@@ -58,7 +56,7 @@ class AddOpGrad : public framework::OperatorWithKernel { ...@@ -58,7 +56,7 @@ class AddOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override {} void InferShape(framework::InferShapeContextBase* ctx) const override {}
}; };
} // namespace operators } // namespace operators
......
...@@ -22,24 +22,24 @@ class ClipOp : public framework::OperatorWithKernel { ...@@ -22,24 +22,24 @@ class ClipOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ClipOp should not be null."); "Input(X) of ClipOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ClipOp should not be null."); "Output(Out) of ClipOp should not be null.");
auto x_dims = ctx.Input<Tensor>("X")->dims(); auto x_dims = ctx->GetInputDim("X");
auto max = Attr<float>("max"); auto max = ctx->Attrs().Get<float>("max");
auto min = Attr<float>("min"); auto min = ctx->Attrs().Get<float>("min");
PADDLE_ENFORCE_LT(min, max, "max should be greater than min."); PADDLE_ENFORCE_LT(min, max, "max should be greater than min.");
ctx.Output<Tensor>("Out")->Resize(x_dims); ctx->SetOutputDim("Out", x_dims);
ctx.ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
}; };
template <typename AttrType> template <typename AttrType>
class ClipOpMaker : public framework::OpProtoAndCheckerMaker { class ClipOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ClipOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) ClipOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(Tensor)The input of clip op." "(Tensor)The input of clip op."
...@@ -61,14 +61,13 @@ class ClipOpGrad : public framework::OperatorWithKernel { ...@@ -61,14 +61,13 @@ class ClipOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null"); "Input(Out@GRAD) should not be null");
auto x_dims = ctx.Input<Tensor>("X")->dims(); auto x_dims = ctx->GetInputDim("X");
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X")); if (ctx->HasOutput(framework::GradVarName("X"))) {
if (x_grad != nullptr) { ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
x_grad->Resize(x_dims);
} }
} }
}; };
......
...@@ -24,31 +24,30 @@ class ConcatOp : public framework::OperatorWithKernel { ...@@ -24,31 +24,30 @@ class ConcatOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ConcatOp should not be null."); "Output(Out) of ConcatOp should not be null.");
auto ins = ctx.MultiInput<framework::Tensor>("X"); auto ins = ctx->GetInputsDim("X");
auto *out = ctx.Output<framework::Tensor>("Out"); size_t axis = static_cast<size_t>(ctx->Attrs().Get<int>("axis"));
size_t axis = static_cast<size_t>(ctx.Attr<int>("axis"));
size_t n = ins.size(); size_t n = ins.size();
PADDLE_ENFORCE_GT(n, 1, "Input tensors count should > 1."); PADDLE_ENFORCE_GT(n, 1, "Input tensors count should > 1.");
auto out_dims = ins[0]->dims(); auto out_dims = ins[0];
size_t in_zero_dims_size = out_dims.size(); size_t in_zero_dims_size = out_dims.size();
for (size_t i = 1; i < n; i++) { for (size_t i = 1; i < n; i++) {
for (size_t j = 0; j < in_zero_dims_size; j++) { for (size_t j = 0; j < in_zero_dims_size; j++) {
if (j == axis) { if (j == axis) {
out_dims[axis] += ins[i]->dims()[j]; out_dims[axis] += ins[i][j];
continue; continue;
} }
PADDLE_ENFORCE_EQ(out_dims[j], ins[i]->dims()[j], PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j],
"Input tensors should have the same " "Input tensors should have the same "
"elements except the specify axis.") "elements except the specify axis.")
} }
} }
out->Resize(out_dims); ctx->SetOutputDim("Out", out_dims);
} }
}; };
......
...@@ -215,7 +215,7 @@ class CondOpProtoAndCheckerMaker : public framework::OpProtoAndCheckerMaker { ...@@ -215,7 +215,7 @@ class CondOpProtoAndCheckerMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC( AddComment(R"DOC(
Sample dependent Cond Operator: Sample dependent Cond Operator:
Given Cond[i] as a 1/0 vector to indicate true/false Given Cond[i] as a 1/0 vector to indicate true/false
The equation is: The equation is:
Out[i] = subnet_t[i], if Cond[i] == true Out[i] = subnet_t[i], if Cond[i] == true
Out[i] = subnet_t[i], if Cond[i] == false Out[i] = subnet_t[i], if Cond[i] == false
)DOC"); )DOC");
......
...@@ -27,27 +27,25 @@ class Conv2DOp : public framework::OperatorWithKernel { ...@@ -27,27 +27,25 @@ class Conv2DOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Input"), PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of Conv2DOp should not be null."); "Input(Input) of Conv2DOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Filter"), PADDLE_ENFORCE(ctx->HasInput("Filter"),
"Input(Filter) of Conv2DOp should not be null."); "Input(Filter) of Conv2DOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Output"), PADDLE_ENFORCE(ctx->HasOutput("Output"),
"Output(Output) of Conv2DOp should not be null."); "Output(Output) of Conv2DOp should not be null.");
auto in = ctx.Input<Tensor>("Input"); auto in_dims = ctx->GetInputDim("Input");
auto filter = ctx.Input<Tensor>("Filter"); auto filter_dims = ctx->GetInputDim("Filter");
auto out = ctx.Output<framework::Tensor>("Output"); std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> strides = Attr<std::vector<int>>("strides"); std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
std::vector<int> paddings = Attr<std::vector<int>>("paddings"); int groups = ctx->Attrs().Get<int>("groups");
int groups = Attr<int>("groups"); int input_channels = in_dims[1];
int input_channels = in->dims()[1]; int output_channels = filter_dims[0];
int output_channels = filter->dims()[0];
PADDLE_ENFORCE_EQ(in_dims.size(), 4, "Conv2DOp input should be 4-D.");
PADDLE_ENFORCE_EQ(in->dims().size(), 4, "Conv2DOp input should be 4-D."); PADDLE_ENFORCE_EQ(filter_dims.size(), 4, "Conv2DOp filter should be 4-D.");
PADDLE_ENFORCE_EQ(filter->dims().size(), 4, PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups,
"Conv2DOp filter should be 4-D.");
PADDLE_ENFORCE_EQ(input_channels, filter->dims()[1] * groups,
"The number of input channels should be equal to filter " "The number of input channels should be equal to filter "
"channels * groups."); "channels * groups.");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
...@@ -55,17 +53,17 @@ class Conv2DOp : public framework::OperatorWithKernel { ...@@ -55,17 +53,17 @@ class Conv2DOp : public framework::OperatorWithKernel {
"The number of output channels should be divided by groups."); "The number of output channels should be divided by groups.");
auto output_height = auto output_height =
outputSize(in->dims()[2], filter->dims()[2], paddings[0], strides[0]); outputSize(in_dims[2], filter_dims[2], paddings[0], strides[0]);
auto output_width = auto output_width =
outputSize(in->dims()[3], filter->dims()[3], paddings[1], strides[1]); outputSize(in_dims[3], filter_dims[3], paddings[1], strides[1]);
out->Resize( ctx->SetOutputDim(
{in->dims()[0], filter->dims()[0], output_height, output_width}); "Output", {in_dims[0], filter_dims[0], output_height, output_width});
} }
}; };
class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker { class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
Conv2DOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) Conv2DOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"Input", "Input",
...@@ -108,14 +106,15 @@ class Conv2DOpGrad : public framework::OperatorWithKernel { ...@@ -108,14 +106,15 @@ class Conv2DOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
auto in = ctx.Input<Tensor>("Input"); auto in_dims = ctx->GetInputDim("Input");
auto filter = ctx.Input<Tensor>("Filter"); auto filter_dims = ctx->GetInputDim("Filter");
auto d_in = ctx.Output<framework::Tensor>(framework::GradVarName("Input")); if (ctx->HasOutput(framework::GradVarName("Input"))) {
auto d_filter = ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
ctx.Output<framework::Tensor>(framework::GradVarName("Filter")); }
if (d_in) d_in->Resize(in->dims()); if (ctx->HasOutput(framework::GradVarName("Filter"))) {
if (d_filter) d_filter->Resize(filter->dims()); ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
}
} }
}; };
......
...@@ -24,22 +24,22 @@ class CosSimOp : public framework::OperatorWithKernel { ...@@ -24,22 +24,22 @@ class CosSimOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
// notnull check // notnull check
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of CosSimOp should not be null."); "Input(X) of CosSimOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), PADDLE_ENFORCE(ctx->HasInput("Y"),
"Input(Y) of CosSimOp should not be null."); "Input(Y) of CosSimOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of CosSimOp should not be null."); "Output(Out) of CosSimOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("XNorm"), PADDLE_ENFORCE(ctx->HasOutput("XNorm"),
"Output(XNorm) of CosSimOp should not be null."); "Output(XNorm) of CosSimOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("YNorm"), PADDLE_ENFORCE(ctx->HasOutput("YNorm"),
"Output(YNorm) of CosSimOp should not be null."); "Output(YNorm) of CosSimOp should not be null.");
// shape check // shape check
auto x_dims = ctx.Input<Tensor>("X")->dims(); auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx.Input<Tensor>("Y")->dims(); auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(), PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(),
"Ranks of Input(X) and Input(Y) must be equal."); "Ranks of Input(X) and Input(Y) must be equal.");
...@@ -54,16 +54,16 @@ class CosSimOp : public framework::OperatorWithKernel { ...@@ -54,16 +54,16 @@ class CosSimOp : public framework::OperatorWithKernel {
" just 1 (which will be broadcasted to match Input(X))."); " just 1 (which will be broadcasted to match Input(X)).");
// resize tensor // resize tensor
ctx.Output<framework::Tensor>("Out")->Resize({x_dims[0], 1}); ctx->SetOutputDim("Out", {x_dims[0], 1});
ctx.Output<framework::Tensor>("XNorm")->Resize({x_dims[0], 1}); ctx->SetOutputDim("XNorm", {x_dims[0], 1});
ctx.Output<framework::Tensor>("YNorm")->Resize({y_dims[0], 1}); ctx->SetOutputDim("YNorm", {y_dims[0], 1});
ctx.ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
}; };
class CosSimOpMaker : public framework::OpProtoAndCheckerMaker { class CosSimOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
CosSimOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) CosSimOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The 1st input of cos_sim op."); AddInput("X", "The 1st input of cos_sim op.");
AddInput("Y", "The 2nd input of cos_sim op."); AddInput("Y", "The 2nd input of cos_sim op.");
...@@ -98,27 +98,23 @@ class CosSimOpGrad : public framework::OperatorWithKernel { ...@@ -98,27 +98,23 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
// notnull check // notnull check
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) must not be null."); PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("XNorm"), PADDLE_ENFORCE(ctx->HasInput("XNorm"), "Input(XNorm) must not be null.");
"Input(XNorm) must not be null."); PADDLE_ENFORCE(ctx->HasInput("YNorm"), "Input(YNorm) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("YNorm"), PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) must not be null.");
"Input(YNorm) must not be null."); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Out"), "Input(Out@GRAD) must not be null.");
"Input(Out) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) must not be null.");
// shape check // shape check
auto x_dims = ctx.Input<Tensor>("X")->dims(); auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx.Input<Tensor>("Y")->dims(); auto y_dims = ctx->GetInputDim("Y");
auto xnorm_dims = ctx.Input<Tensor>("XNorm")->dims(); auto xnorm_dims = ctx->GetInputDim("XNorm");
auto ynorm_dims = ctx.Input<Tensor>("YNorm")->dims(); auto ynorm_dims = ctx->GetInputDim("YNorm");
auto out_dims = ctx.Input<Tensor>("Out")->dims(); auto out_dims = ctx->GetInputDim("Out");
auto out_grad_dims = auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));
ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
"Ranks of Input(X) and Input(Y) must be equal."); "Ranks of Input(X) and Input(Y) must be equal.");
...@@ -143,10 +139,14 @@ class CosSimOpGrad : public framework::OperatorWithKernel { ...@@ -143,10 +139,14 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
"Shape of Input(Out@Grad) must be [X.Dim(0), 1]."); "Shape of Input(Out@Grad) must be [X.Dim(0), 1].");
// resize tensor // resize tensor
auto *x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X")); auto x_grad_name = framework::GradVarName("X");
auto *y_grad = ctx.Output<framework::Tensor>(framework::GradVarName("Y")); auto y_grad_name = framework::GradVarName("Y");
if (x_grad) x_grad->Resize(x_dims); if (ctx->HasOutput(x_grad_name)) {
if (y_grad) y_grad->Resize(y_dims); ctx->SetOutputDim(x_grad_name, x_dims);
}
if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, y_dims);
}
} }
}; };
......
...@@ -25,16 +25,14 @@ class CropOp : public framework::OperatorWithKernel { ...@@ -25,16 +25,14 @@ class CropOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of CropOp should not be null."); "Input(X) of CropOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of CropOp should not be null."); "Output(Out) of CropOp should not be null.");
auto x_dim = ctx.Input<Tensor>("X")->dims(); auto x_dim = ctx->GetInputDim("X");
auto *y = ctx.Input<Tensor>("Y"); if (!ctx->HasInput("Y")) {
auto *out = ctx.Output<Tensor>("Out"); auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
if (y == nullptr) {
auto shape = Attr<std::vector<int>>("shape");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
int64_t(shape.size()), x_dim.size(), int64_t(shape.size()), x_dim.size(),
"Shape size should be equal to dimention size of input tensor."); "Shape size should be equal to dimention size of input tensor.");
...@@ -42,19 +40,20 @@ class CropOp : public framework::OperatorWithKernel { ...@@ -42,19 +40,20 @@ class CropOp : public framework::OperatorWithKernel {
for (size_t i = 0; i < shape.size(); ++i) { for (size_t i = 0; i < shape.size(); ++i) {
tensor_shape[i] = static_cast<int64_t>(shape[i]); tensor_shape[i] = static_cast<int64_t>(shape[i]);
} }
out->Resize(framework::make_ddim(tensor_shape)); ctx->SetOutputDim("Out", framework::make_ddim(tensor_shape));
} else { } else {
PADDLE_ENFORCE_EQ(framework::arity(x_dim), framework::arity(y->dims()), auto y_dim = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(framework::arity(x_dim), framework::arity(y_dim),
"Tensor rank of both CropOp's " "Tensor rank of both CropOp's "
"inputs must be same."); "inputs must be same.");
out->Resize(y->dims()); ctx->SetOutputDim("Out", y_dim);
} }
} }
}; };
class CropOpMaker : public framework::OpProtoAndCheckerMaker { class CropOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
CropOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) CropOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"The input of pad op. " "The input of pad op. "
...@@ -78,12 +77,12 @@ class CropOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -78,12 +77,12 @@ class CropOpMaker : public framework::OpProtoAndCheckerMaker {
Crop Operator. Crop Operator.
Crop input into output, as specified by offsets and shape. Crop input into output, as specified by offsets and shape.
There are two ways to set shape: There are two ways to set shape:
1. referenc input: crop input X as shape as reference input. 1. referenc input: crop input X as shape as reference input.
The dimension of reference input should The dimension of reference input should
be as same as input X. be as same as input X.
2. shape list: crop input X by shape described by a list<int>. 2. shape list: crop input X by shape described by a list<int>.
The size of shape list should be as same as The size of shape list should be as same as
dimension size of input X. dimension size of input X.
The input should be a k-D tensor(k > 0 and k < 7). As an example: The input should be a k-D tensor(k > 0 and k < 7). As an example:
...@@ -94,15 +93,15 @@ Given: ...@@ -94,15 +93,15 @@ Given:
[0, 3, 4, 0, 0] [0, 3, 4, 0, 0]
[0, 0, 0, 0, 0]] [0, 0, 0, 0, 0]]
and and
offsets = [0, 1] offsets = [0, 1]
and and
shape = [2, 2] shape = [2, 2]
then we get then we get
Out = [[1, 2], Out = [[1, 2],
[3, 4]] [3, 4]]
...@@ -116,14 +115,14 @@ class CropOpGrad : public framework::OperatorWithKernel { ...@@ -116,14 +115,14 @@ class CropOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null"); "Input(Out@GRAD) should not be null");
auto x_dims = ctx.Input<Tensor>("X")->dims(); auto x_dims = ctx->GetInputDim("X");
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X")); auto x_grad_name = framework::GradVarName("X");
if (x_grad != nullptr) { if (ctx->HasOutput(x_grad_name)) {
x_grad->Resize(x_dims); ctx->SetOutputDim(x_grad_name, x_dims);
} }
} }
}; };
......
...@@ -22,33 +22,30 @@ class CrossEntropyOp : public framework::OperatorWithKernel { ...@@ -22,33 +22,30 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should be not null."); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
"Input(Label) should be not null."); PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"),
"Output(Y) should be not null."); auto x_dims = ctx->GetInputDim("X");
auto label_dims = ctx->GetInputDim("Label");
auto x = ctx.Input<Tensor>("X"); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
auto label = ctx.Input<Tensor>("Label"); PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2.");
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank should be 2."); PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
PADDLE_ENFORCE_EQ(label->dims().size(), 2,
"Input(Label)'s rank should be 2.");
PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
"The 1st dimension of Input(X) and Input(Label) should " "The 1st dimension of Input(X) and Input(Label) should "
"be equal."); "be equal.");
if (ctx.Attr<bool>("softLabel")) { if (ctx->Attrs().Get<bool>("softLabel")) {
PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1], PADDLE_ENFORCE_EQ(x_dims[1], label_dims[1],
"If Attr(softLabel) == true, the 2nd dimension of " "If Attr(softLabel) == true, the 2nd dimension of "
"Input(X) and Input(Label) should be equal."); "Input(X) and Input(Label) should be equal.");
} else { } else {
PADDLE_ENFORCE_EQ(label->dims()[1], 1, PADDLE_ENFORCE_EQ(label_dims[1], 1,
"If Attr(softLabel) == false, the 2nd dimension of " "If Attr(softLabel) == false, the 2nd dimension of "
"Input(Label) should be 1."); "Input(Label) should be 1.");
} }
ctx.Output<Tensor>("Y")->Resize({x->dims()[0], 1}); ctx->SetOutputDim("Y", {x_dims[0], 1});
ctx.ShareLoD("X", /*->*/ "Y"); ctx->ShareLoD("X", /*->*/ "Y");
} }
}; };
...@@ -57,50 +54,45 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { ...@@ -57,50 +54,45 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should be not null."); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
"Input(Label) should be not null."); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")), "Input(Y@GRAD) shoudl be not null.");
"Input(Y@GRAD) shoudl be not null."); PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(framework::GradVarName("X")), "Output(X@GRAD) should be not null.");
"Output(X@GRAD) should be not null.");
auto x_dims = ctx->GetInputDim("X");
auto x = ctx.Input<Tensor>("X"); auto label_dims = ctx->GetInputDim("Label");
auto label = ctx.Input<Tensor>("Label"); auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
auto dy = ctx.Input<Tensor>(framework::GradVarName("Y")); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank should be 2."); PADDLE_ENFORCE_EQ(dy_dims.size(), 2, "Input(Y@Grad)'s rank should be 2.");
PADDLE_ENFORCE_EQ(dy->dims().size(), 2, PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2.");
"Input(Y@Grad)'s rank should be 2."); PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
PADDLE_ENFORCE_EQ(label->dims().size(), 2,
"Input(Label)'s rank should be 2.");
PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
"The 1st dimension of Input(X) and Input(Label) should " "The 1st dimension of Input(X) and Input(Label) should "
"be equal."); "be equal.");
PADDLE_ENFORCE_EQ(x->dims()[0], dy->dims()[0], PADDLE_ENFORCE_EQ(x_dims[0], dy_dims[0],
"The 1st dimension of Input(X) and Input(Y@Grad) should " "The 1st dimension of Input(X) and Input(Y@Grad) should "
"be equal."); "be equal.");
PADDLE_ENFORCE_EQ(dy->dims()[1], 1, PADDLE_ENFORCE_EQ(dy_dims[1], 1,
"The 2nd dimension of Input(Y@Grad) should be 1."); "The 2nd dimension of Input(Y@Grad) should be 1.");
if (ctx.Attr<bool>("softLabel")) { if (ctx->Attrs().Get<bool>("softLabel")) {
PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1], PADDLE_ENFORCE_EQ(x_dims[1], label_dims[1],
"When Attr(softLabel) == true, the 2nd dimension of " "When Attr(softLabel) == true, the 2nd dimension of "
"Input(X) and Input(Label) should be equal."); "Input(X) and Input(Label) should be equal.");
} else { } else {
PADDLE_ENFORCE_EQ(label->dims()[1], 1, PADDLE_ENFORCE_EQ(label_dims[1], 1,
"When Attr(softLabel) == false, the 2nd dimension of " "When Attr(softLabel) == false, the 2nd dimension of "
"Input(Label) should be 1."); "Input(Label) should be 1.");
} }
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
dx->Resize(x->dims());
} }
}; };
class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker { class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
CrossEntropyOpMaker(framework::OpProto *proto, CrossEntropyOpMaker(framework::OpProto* proto,
framework::OpAttrChecker *op_checker) framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(Tensor, default Tensor<float>), a 2-D tensor with shape N x D, " "(Tensor, default Tensor<float>), a 2-D tensor with shape N x D, "
......
...@@ -24,25 +24,25 @@ class DropoutOp : public framework::OperatorWithKernel { ...@@ -24,25 +24,25 @@ class DropoutOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE_GE(ctx.Attr<float>("dropout_prob"), 0); PADDLE_ENFORCE_GE(ctx->Attrs().Get<float>("dropout_prob"), 0);
PADDLE_ENFORCE_LE(ctx.Attr<float>("dropout_prob"), 1); PADDLE_ENFORCE_LE(ctx->Attrs().Get<float>("dropout_prob"), 1);
auto dims = ctx.Input<Tensor>("X")->dims(); auto x_dims = ctx->GetInputDim("X");
ctx.Output<Tensor>("Out")->Resize(dims); ctx->SetOutputDim("Out", x_dims);
if (ctx.Attr<bool>("is_training")) { if (ctx->Attrs().Get<bool>("is_training") == 1) {
ctx.Output<Tensor>("Mask")->Resize(dims); ctx->SetOutputDim("Mask", x_dims);
} }
ctx.ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
}; };
template <typename AttrType> template <typename AttrType>
class DropoutOpMaker : public framework::OpProtoAndCheckerMaker { class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
DropoutOpMaker(framework::OpProto *proto, DropoutOpMaker(framework::OpProto* proto,
framework::OpAttrChecker *op_checker) framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<AttrType>("dropout_prob", "Probability of setting units to zero.") AddAttr<AttrType>("dropout_prob", "Probability of setting units to zero.")
.SetDefault(.5f); .SetDefault(.5f);
...@@ -70,27 +70,26 @@ class DropoutOpGrad : public framework::OperatorWithKernel { ...@@ -70,27 +70,26 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx.Attr<bool>("is_training"), PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_training"), 1,
"GradOp is only callable when is_training is true"); "GradOp is only callable when is_training is true");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Mask"), "Mask must not be null."); PADDLE_ENFORCE(ctx->HasInput("Mask"), "Mask must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) must not be null."); "Input(Out@GRAD) must not be null.");
PADDLE_ENFORCE_GE(ctx.Attr<AttrType>("dropout_prob"), 0); PADDLE_ENFORCE_GE(ctx->Attrs().Get<AttrType>("dropout_prob"), 0);
PADDLE_ENFORCE_LE(ctx.Attr<AttrType>("dropout_prob"), 1); PADDLE_ENFORCE_LE(ctx->Attrs().Get<AttrType>("dropout_prob"), 1);
auto x_dims = ctx.Input<Tensor>("X")->dims(); auto x_dims = ctx->GetInputDim("X");
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims(); auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(x_dims, out_dims, PADDLE_ENFORCE_EQ(x_dims, out_dims,
"Dimensions of Input(X) and Out@Grad must be the same."); "Dimensions of Input(X) and Out@Grad must be the same.");
auto mask_dims = ctx.Input<Tensor>("Mask")->dims(); auto mask_dims = ctx->GetInputDim("Mask");
PADDLE_ENFORCE_EQ(x_dims, mask_dims, PADDLE_ENFORCE_EQ(x_dims, mask_dims,
"Dimensions of Input(X) and Mask must be the same."); "Dimensions of Input(X) and Mask must be the same.");
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X")); ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
x_grad->Resize(x_dims);
} }
}; };
......
...@@ -202,21 +202,20 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -202,21 +202,20 @@ class ElementwiseOp : public framework::OperatorWithKernel {
protected: protected:
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
void InferShape(const framework::InferShapeContext& ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of elementwise op should not be null"); "Input(X) of elementwise op should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), PADDLE_ENFORCE(ctx->HasInput("Y"),
"Input(Y) of elementwise op should not be null"); "Input(Y) of elementwise op should not be null");
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE(ctx->HasOutput("Out"),
ctx.OutputVar("Out"), "Output(Out) of elementwise op should not be null.");
"Output(Out) of elementwise op should not be null.");
auto x_dim = ctx->GetInputDim("X");
auto x_dim = ctx.Input<Tensor>("X")->dims(); auto y_dim = ctx->GetInputDim("Y");
auto y_dim = ctx.Input<Tensor>("Y")->dims();
PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(), PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
"Rank of first input must >= rank of second input.") "Rank of first input must >= rank of second input.")
ctx.Output<framework::Tensor>("Out")->Resize(x_dim); ctx->SetOutputDim("Out", x_dim);
ctx.ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
}; };
...@@ -234,7 +233,7 @@ must be small or equal to X's dimensions. ...@@ -234,7 +233,7 @@ must be small or equal to X's dimensions.
)DOC"); )DOC");
AddAttr<int>("axis", AddAttr<int>("axis",
R"DOC( R"DOC(
When the shape(Y) does not equal the shape(X),Y will be broadcasted When the shape(Y) does not equal the shape(X),Y will be broadcasted
to match the shape of X and axis should be dimension index Y in X to match the shape of X and axis should be dimension index Y in X
)DOC") )DOC")
.SetDefault(-1) .SetDefault(-1)
...@@ -244,7 +243,7 @@ to match the shape of X and axis should be dimension index Y in X ...@@ -244,7 +243,7 @@ to match the shape of X and axis should be dimension index Y in X
comment_ = R"DOC( comment_ = R"DOC(
Limited elementwise {name} operator.The equation is: Out = {equation}. Limited elementwise {name} operator.The equation is: Out = {equation}.
1. The shape of Y should be same with X or 1. The shape of Y should be same with X or
2. Y's shape is a subset of X. 2. Y's shape is a subset of X.
Y will be broadcasted to match the shape of X and axis should be dimension index Y in X. Y will be broadcasted to match the shape of X and axis should be dimension index Y in X.
example: example:
...@@ -284,27 +283,26 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { ...@@ -284,27 +283,26 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
protected: protected:
void InferShape(const framework::InferShapeContext& ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null"); PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null"); "Input(Out@GRAD) should not be null");
auto x_dims = ctx.Input<Tensor>("X")->dims(); auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx.Input<Tensor>("Y")->dims(); auto y_dims = ctx->GetInputDim("Y");
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims(); auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
auto* x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* y_grad = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
"Rank of first input must >= rank of second input.") "Rank of first input must >= rank of second input.")
if (x_grad) { auto x_grad_name = framework::GradVarName("X");
x_grad->Resize(x_dims); auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
} }
if (ctx->HasOutput(y_grad_name)) {
if (y_grad) { ctx->SetOutputDim(y_grad_name, y_dims);
y_grad->Resize(y_dims);
} }
} }
}; };
......
...@@ -22,15 +22,13 @@ class FillZerosLikeOp : public framework::OperatorWithKernel { ...@@ -22,15 +22,13 @@ class FillZerosLikeOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of FillZerosLikeOp should not be null."); "Input(X) of FillZerosLikeOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"), PADDLE_ENFORCE(ctx->HasOutput("Y"),
"Output(Y) of FillZerosLikeOp should not be null."); "Output(Y) of FillZerosLikeOp should not be null.");
ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
ctx.Output<framework::Tensor>("Y")->Resize( ctx->ShareLoD("X", /*->*/ "Y");
ctx.Input<framework::Tensor>("X")->dims());
ctx.ShareLoD("X", /*->*/ "Y");
} }
}; };
......
...@@ -23,19 +23,19 @@ class GatherOp : public framework::OperatorWithKernel { ...@@ -23,19 +23,19 @@ class GatherOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of GatherOp should not be null."); "Input(X) of GatherOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Index"), PADDLE_ENFORCE(ctx->HasInput("Index"),
"Input(Index) of GatherOp should not be null."); "Input(Index) of GatherOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of GatherOp should not be null."); "Output(Out) of GatherOp should not be null.");
int batch_size = ctx.Input<Tensor>("Index")->dims()[0]; int batch_size = ctx->GetInputDim("Index")[0];
PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0"); PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0");
framework::DDim output_dims(ctx.Input<Tensor>("X")->dims()); framework::DDim output_dims(ctx->GetInputDim("X"));
output_dims[0] = batch_size; output_dims[0] = batch_size;
ctx.Output<framework::Tensor>("Out")->Resize(output_dims); ctx->SetOutputDim("Out", output_dims);
} }
}; };
...@@ -44,23 +44,20 @@ class GatherGradOp : public framework::OperatorWithKernel { ...@@ -44,23 +44,20 @@ class GatherGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
auto X_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
auto X = ctx.Input<Tensor>("X");
X_grad->Resize(X->dims());
} }
}; };
class GatherOpMaker : public framework::OpProtoAndCheckerMaker { class GatherOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
GatherOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) GatherOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The source input of gather op"); AddInput("X", "The source input of gather op");
AddInput("Index", "The index input of gather op"); AddInput("Index", "The index input of gather op");
AddOutput("Out", "The output of add op"); AddOutput("Out", "The output of add op");
AddComment(R"DOC( AddComment(R"DOC(
Gather Operator by selecting from the first axis, Gather Operator by selecting from the first axis,
Out = X[Index] Out = X[Index]
)DOC"); )DOC");
......
...@@ -43,13 +43,10 @@ class GaussianRandomOp : public framework::OperatorWithKernel { ...@@ -43,13 +43,10 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext& ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE(ctx->HasOutput("Out"),
ctx.OutputVar("Out"), "Output(Out) of GaussianRandomOp should not be null.");
"Output(Out) of GaussianRandomOp should not be null."); auto dims = ctx->Attrs().Get<std::vector<int>>("dims");
auto* tensor = ctx.Output<framework::Tensor>("Out");
auto dims = Attr<std::vector<int>>("dims");
std::vector<int64_t> temp; std::vector<int64_t> temp;
temp.reserve(dims.size()); temp.reserve(dims.size());
for (auto dim : dims) { for (auto dim : dims) {
...@@ -57,7 +54,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel { ...@@ -57,7 +54,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
} }
PADDLE_ENFORCE(dims.size() > 0UL, PADDLE_ENFORCE(dims.size() > 0UL,
"dims can be one int or array. dims must be set."); "dims can be one int or array. dims must be set.");
tensor->Resize(framework::make_ddim(temp)); ctx->SetOutputDim("Out", framework::make_ddim(temp));
} }
}; };
......
...@@ -22,27 +22,26 @@ class LookupTableOp : public framework::OperatorWithKernel { ...@@ -22,27 +22,26 @@ class LookupTableOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("W"), PADDLE_ENFORCE(ctx->HasInput("W"),
"Input(W) of LookupTableOp should not be null."); "Input(W) of LookupTableOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Ids"), PADDLE_ENFORCE(ctx->HasInput("Ids"),
"Input(Ids) of LookupTableOp should not be null."); "Input(Ids) of LookupTableOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of LookupTableOp should not be null."); "Output(Out) of LookupTableOp should not be null.");
auto table_t = ctx.Input<Tensor>("W"); auto table_dims = ctx->GetInputDim("W");
auto ids_t = ctx.Input<Tensor>("Ids"); auto ids_dims = ctx->GetInputDim("Ids");
auto output_t = ctx.Output<framework::Tensor>("Out");
ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]});
output_t->Resize({ids_t->dims()[0], table_t->dims()[1]}); ctx->ShareLoD("Ids", /*->*/ "Out");
ctx.ShareLoD("Ids", /*->*/ "Out");
} }
}; };
class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
LookupTableOpMaker(framework::OpProto *proto, LookupTableOpMaker(framework::OpProto* proto,
framework::OpAttrChecker *op_checker) framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("W", AddInput("W",
"An input represents embedding tensors," "An input represents embedding tensors,"
...@@ -66,11 +65,9 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { ...@@ -66,11 +65,9 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &context) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
auto table = context.Input<Tensor>("W"); auto table_dims = ctx->GetInputDim("W");
auto d_table = ctx->SetOutputDim(framework::GradVarName("W"), table_dims);
context.Output<framework::Tensor>(framework::GradVarName("W"));
d_table->Resize(table->dims());
} }
}; };
......
...@@ -22,37 +22,36 @@ class LstmUnitOp : public framework::OperatorWithKernel { ...@@ -22,37 +22,36 @@ class LstmUnitOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LSTM should not be null.");
"Input(X) of LSTM should not be null."); PADDLE_ENFORCE(ctx->HasInput("C_prev"),
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("C_prev"), "Input(C_prev) of LSTM should not be null.");
"Input(C_prev) of LSTM should not be null."); PADDLE_ENFORCE(ctx->HasOutput("C"),
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("C"), "Output(C) of LSTM should not be null.");
"Output(C) of LSTM should not be null."); PADDLE_ENFORCE(ctx->HasOutput("H"),
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("H"), "Output(H) of LSTM should not be null.");
"Output(H) of LSTM should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto *x = ctx.Input<framework::Tensor>("X"); auto c_prev_dims = ctx->GetInputDim("C_prev");
auto *c_prev = ctx.Input<framework::Tensor>("C_prev");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2."); PADDLE_ENFORCE(x_dims[0] == c_prev_dims[0],
PADDLE_ENFORCE(x->dims()[0] == c_prev->dims()[0],
"Batch size of inputs and states must be equal"); "Batch size of inputs and states must be equal");
PADDLE_ENFORCE(x->dims()[1] == c_prev->dims()[1] * 4, PADDLE_ENFORCE(x_dims[1] == c_prev_dims[1] * 4,
"Dimension of FC should equal to prev state * 4"); "Dimension of FC should equal to prev state * 4");
int b_size = c_prev->dims()[0]; // batch size int b_size = c_prev_dims[0]; // batch size
int s_dim = c_prev->dims()[1]; // state dim int s_dim = c_prev_dims[1]; // state dim
ctx.Output<framework::LoDTensor>("C")->Resize({b_size, s_dim}); ctx->SetOutputDim("C", {b_size, s_dim});
ctx.Output<framework::LoDTensor>("H")->Resize({b_size, s_dim}); ctx->SetOutputDim("H", {b_size, s_dim});
} }
}; };
template <typename AttrType> template <typename AttrType>
class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker { class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
LstmUnitOpMaker(framework::OpProto *proto, LstmUnitOpMaker(framework::OpProto* proto,
framework::OpAttrChecker *op_checker) framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "FC input before the non-linear activation."); AddInput("X", "FC input before the non-linear activation.");
AddInput( AddInput(
...@@ -63,11 +62,11 @@ class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -63,11 +62,11 @@ class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(Lstm-Unit Operator AddComment(R"DOC(Lstm-Unit Operator
Equation: Equation:
i, f, o, j = split(X) i, f, o, j = split(X)
C = C_prev * sigm(f + forget_bias) + sigm(i) * tanh(j) C = C_prev * sigm(f + forget_bias) + sigm(i) * tanh(j)
H = C * sigm(o) H = C * sigm(o)
)DOC"); )DOC");
AddAttr<AttrType>("forget_bias", "The forget bias of Lstm Unit.") AddAttr<AttrType>("forget_bias", "The forget bias of Lstm Unit.")
.SetDefault(0.0); .SetDefault(0.0);
...@@ -79,15 +78,14 @@ class LstmUnitGradOp : public framework::OperatorWithKernel { ...@@ -79,15 +78,14 @@ class LstmUnitGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("C")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("C")),
"Input(C@GRAD) should not be null"); "Input(C@GRAD) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("H")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("H")),
"Input(H@GRAD) should not be null"); "Input(H@GRAD) should not be null");
ctx.Output<framework::LoDTensor>(framework::GradVarName("X")) ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
->Resize(ctx.Input<Tensor>("X")->dims()); ctx->SetOutputDim(framework::GradVarName("C_prev"),
ctx.Output<framework::LoDTensor>(framework::GradVarName("C_prev")) ctx->GetInputDim("C_prev"));
->Resize(ctx.Input<Tensor>("C_prev")->dims());
} }
}; };
......
...@@ -22,18 +22,18 @@ class MeanOp : public framework::OperatorWithKernel { ...@@ -22,18 +22,18 @@ class MeanOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of MeanOp should not be null."); "Input(X) of MeanOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of MeanOp should not be null."); "Output(Out) of MeanOp should not be null.");
ctx.Output<framework::Tensor>("Out")->Resize({1}); ctx->SetOutputDim("Out", {1});
} }
}; };
class MeanOpMaker : public framework::OpProtoAndCheckerMaker { class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
MeanOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) MeanOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of mean op"); AddInput("X", "The input of mean op");
AddOutput("Out", "The output of mean op").NotInGradient(); AddOutput("Out", "The output of mean op").NotInGradient();
...@@ -47,9 +47,8 @@ class MeanGradOp : public framework::OperatorWithKernel { ...@@ -47,9 +47,8 @@ class MeanGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
ctx.Output<framework::Tensor>(framework::GradVarName("X")) ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
->Resize(ctx.Input<Tensor>("X")->dims());
} }
}; };
......
...@@ -26,22 +26,22 @@ class MinusOp : public framework::OperatorWithKernel { ...@@ -26,22 +26,22 @@ class MinusOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of MinusOp should not be null."); "Input(X) of MinusOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), PADDLE_ENFORCE(ctx->HasInput("Y"),
"Input(Y) of MinusOp should not be null."); "Input(Y) of MinusOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of MinusOp should not be null."); "Output(Out) of MinusOp should not be null.");
auto *left_tensor = ctx.Input<framework::Tensor>("X"); auto x_dims = ctx->GetInputDim("X");
auto *right_tensor = ctx.Input<framework::Tensor>("Y"); auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
left_tensor->numel(), right_tensor->numel(), x_dims, y_dims,
"Minus operator must take two tensor with same num of elements"); "Minus operator must take two tensor with same num of elements");
ctx.Output<framework::Tensor>("Out")->Resize(left_tensor->dims()); ctx->SetOutputDim("Out", x_dims);
ctx.ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
}; };
......
...@@ -22,20 +22,19 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel { ...@@ -22,20 +22,19 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext& context) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(context.InputVar("X"), "X must be initialized."); PADDLE_ENFORCE(ctx->HasInput("X"), "X must be initialized.");
PADDLE_ENFORCE_NOT_NULL(context.InputVar("Y"), "Y must be initialized."); PADDLE_ENFORCE(ctx->HasInput("Y"), "Y must be initialized.");
auto* x = context.Input<Tensor>("X"); auto x_dims = ctx->GetInputDim("X");
auto* y = context.Input<Tensor>("Y"); auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(x->dims(), y->dims(), PADDLE_ENFORCE_EQ(x_dims, y_dims, "The shape of X and Y must be the same.");
"The shape of X and Y must be the same."); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "The tensor rank of X must be 2.");
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "The tensor rank of X must be 2."); PADDLE_ENFORCE_EQ(x_dims[1], 1, "The 2nd dimension of X must be 1.");
PADDLE_ENFORCE_EQ(x->dims()[1], 1, "The 2nd dimension of X must be 1.");
context.Output<framework::Tensor>("IntermediateVal")->Resize(x->dims()); ctx->SetOutputDim("IntermediateVal", x_dims);
context.Output<framework::Tensor>("Out")->Resize({x->dims()[0], 1}); ctx->SetOutputDim("Out", {x_dims[0], 1});
} }
}; };
...@@ -75,27 +74,28 @@ class ModifiedHuberLossGradOp : public framework::OperatorWithKernel { ...@@ -75,27 +74,28 @@ class ModifiedHuberLossGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext& context) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
auto* x = context.Input<Tensor>("X"); PADDLE_ENFORCE(ctx->HasInput("X"), "X must be initialized.");
auto* y = context.Input<Tensor>("Y"); PADDLE_ENFORCE(ctx->HasInput("Y"), "Y must be initialized.");
auto* intermediate_val = context.Input<Tensor>("IntermediateVal"); PADDLE_ENFORCE(ctx->HasInput("IntermediateVal"),
auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out")); "Intermediate value must not be null.");
auto* x_grad = PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
context.Output<framework::Tensor>(framework::GradVarName("X")); "Input(Out@Grad) must not be null.");
PADDLE_ENFORCE_NOT_NULL(x, "X must be initialized."); auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_NOT_NULL(y, "Y must be initialized."); auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_NOT_NULL(intermediate_val, auto intermediate_dims = ctx->GetInputDim("IntermediateVal");
"Intermediate value must not be null."); auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_NOT_NULL(out_grad, "Input(Out@Grad) must not be null.");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
intermediate_val->dims(), x->dims(), intermediate_dims, x_dims,
"The shape of X and intermediate value must be the same."); "The shape of X and intermediate value must be the same.");
PADDLE_ENFORCE_EQ(out_grad->dims(), x->dims(), PADDLE_ENFORCE_EQ(out_grad_dims, x_dims,
"The shape of Input(Out@Grad) and X must be the same."); "The shape of Input(Out@Grad) and X must be the same.");
if (x_grad) x_grad->Resize(x->dims()); if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
} }
}; };
......
...@@ -24,27 +24,23 @@ class MulOp : public framework::OperatorWithKernel { ...@@ -24,27 +24,23 @@ class MulOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of MulOp should not be null.");
"Input(X) of MulOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) of MulOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Input(Y) of MulOp should not be null."); "Output(Out) of MulOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
"Output(Out) of MulOp should not be null."); auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto x_dims = ctx.Input<Tensor>("X")->dims(); int x_num_col_dims = ctx->Attrs().Get<int>("x_num_col_dims");
auto y_dims = ctx.Input<Tensor>("Y")->dims(); int y_num_col_dims = ctx->Attrs().Get<int>("y_num_col_dims");
int x_num_col_dims = Attr<int>("x_num_col_dims");
int y_num_col_dims = Attr<int>("y_num_col_dims");
PADDLE_ENFORCE(x_dims.size() > x_num_col_dims, PADDLE_ENFORCE(x_dims.size() > x_num_col_dims,
"The rank of input tensor X(%s) should be larger than " "The rank of input tensor X should be larger than "
"`mul_op`'s `x_num_col_dims`.", "`mul_op`'s `x_num_col_dims`.");
ctx.op().Input("X"));
PADDLE_ENFORCE(y_dims.size() > y_num_col_dims, PADDLE_ENFORCE(y_dims.size() > y_num_col_dims,
"The rank of input tensor Y(%s) should be larger than " "The rank of input tensor Y should be larger than "
"`mul_op`'s `y_num_col_dims`.", "`mul_op`'s `y_num_col_dims`.");
ctx.op().Input("Y"));
auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims); auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims);
auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_col_dims); auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_col_dims);
...@@ -52,24 +48,23 @@ class MulOp : public framework::OperatorWithKernel { ...@@ -52,24 +48,23 @@ class MulOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
x_mat_dims[1], y_mat_dims[0], x_mat_dims[1], y_mat_dims[0],
"First matrix's width must be equal with second matrix's height."); "First matrix's width must be equal with second matrix's height.");
ctx.Output<framework::Tensor>("Out")->Resize( ctx->SetOutputDim("Out", {x_mat_dims[0], y_mat_dims[1]});
{x_mat_dims[0], y_mat_dims[1]}); ctx->ShareLoD("X", /*->*/ "Out");
ctx.ShareLoD("X", /*->*/ "Out");
} }
}; };
class MulOpMaker : public framework::OpProtoAndCheckerMaker { class MulOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
MulOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) MulOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of mul op"); AddInput("X", "The first input of mul op");
AddInput("Y", "The second input of mul op"); AddInput("Y", "The second input of mul op");
AddOutput("Out", "The output of mul op"); AddOutput("Out", "The output of mul op");
AddAttr<int>( AddAttr<int>(
"x_num_col_dims", "x_num_col_dims",
R"DOC(mul_op can take tensors with more than two dimensions as input `X`, R"DOC(mul_op can take tensors with more than two dimensions as input `X`,
in that case, tensors will be reshaped to a matrix. The matrix's first in that case, tensors will be reshaped to a matrix. The matrix's first
dimension(column length) will be the product of tensor's last dimension(column length) will be the product of tensor's last
`num_col_dims` dimensions, and the matrix's second dimension(row length) `num_col_dims` dimensions, and the matrix's second dimension(row length)
will be the product of tensor's first `rank - num_col_dims` dimensions. will be the product of tensor's first `rank - num_col_dims` dimensions.
)DOC") )DOC")
...@@ -100,16 +95,14 @@ class MulOpGrad : public framework::OperatorWithKernel { ...@@ -100,16 +95,14 @@ class MulOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null"); PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null"); "Input(Out@GRAD) should not be null");
auto x_dims = ctx.Input<Tensor>("X")->dims(); auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx.Input<Tensor>("Y")->dims(); auto y_dims = ctx->GetInputDim("Y");
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims(); auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
auto *x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto *y_grad = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
auto x_mat_dims = auto x_mat_dims =
framework::flatten_to_2d(x_dims, Attr<int>("x_num_col_dims")); framework::flatten_to_2d(x_dims, Attr<int>("x_num_col_dims"));
...@@ -125,8 +118,15 @@ class MulOpGrad : public framework::OperatorWithKernel { ...@@ -125,8 +118,15 @@ class MulOpGrad : public framework::OperatorWithKernel {
"The second dimension of Out@GRAD must equal to the second " "The second dimension of Out@GRAD must equal to the second "
"dimension of the second operand."); "dimension of the second operand.");
if (x_grad) x_grad->Resize(x_dims); auto x_grad_name = framework::GradVarName("X");
if (y_grad) y_grad->Resize(y_dims); auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, y_dims);
}
} }
}; };
......
...@@ -24,41 +24,38 @@ class MultiplexOp : public framework::OperatorWithKernel { ...@@ -24,41 +24,38 @@ class MultiplexOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Ids"), PADDLE_ENFORCE(ctx->HasInput("Ids"), "Input(Ids) shouldn't be null.");
"Input(Ids) shouldn't be null."); PADDLE_ENFORCE(!ctx->Inputs("X").empty(),
PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(),
"MultiInput(X) shouldn't be empty."); "MultiInput(X) shouldn't be empty.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) shouldn't be null.");
"Output(Out) shouldn't be null."); auto ids_dim = ctx->GetInputDim("Ids");
auto ids_dim = ctx.Input<Tensor>("Ids")->dims();
PADDLE_ENFORCE( PADDLE_ENFORCE(
ids_dim.size() == 2 && ids_dim[1] == 1, ids_dim.size() == 2 && ids_dim[1] == 1,
"The index tensor must be a vector with size batchSize x 1."); "The index tensor must be a vector with size batchSize x 1.");
auto ins = ctx.MultiInput<Tensor>("X"); auto ins_dims = ctx->GetInputsDim("X");
auto *out = ctx.Output<Tensor>("Out"); auto num_ins = ins_dims.size();
auto num_ins = ins.size();
PADDLE_ENFORCE(num_ins > 1, PADDLE_ENFORCE(num_ins > 1,
"multiplex operator should have more than " "multiplex operator should have more than "
"one candidate input tensors."); "one candidate input tensors.");
auto in_dim = ins[0]->dims(); auto in_dim = ins_dims[0];
PADDLE_ENFORCE(in_dim.size() >= 2, PADDLE_ENFORCE(in_dim.size() >= 2,
"The rank of candidate tensors must be not less than 2."); "The rank of candidate tensors must be not less than 2.");
for (size_t i = 1; i < num_ins; i++) { for (size_t i = 1; i < num_ins; i++) {
auto dim = ins[i]->dims(); auto dim = ins_dims[i];
PADDLE_ENFORCE(in_dim == dim, PADDLE_ENFORCE(in_dim == dim,
"All the candidate tensors must have the same size."); "All the candidate tensors must have the same size.");
} }
out->Resize(in_dim); ctx->SetOutputDim("Out", in_dim);
} }
}; };
class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker { class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
MultiplexOpMaker(framework::OpProto *proto, MultiplexOpMaker(framework::OpProto* proto,
framework::OpAttrChecker *op_checker) framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Ids", "The index tensor of multiplex operator."); AddInput("Ids", "The index tensor of multiplex operator.");
AddInput("X", "The candidate tensors of multiplex operator.") AddInput("X", "The candidate tensors of multiplex operator.")
...@@ -88,21 +85,19 @@ class MultiplexGradOp : public framework::OperatorWithKernel { ...@@ -88,21 +85,19 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(), PADDLE_ENFORCE(!ctx->Inputs("X").empty(), "Input(X) should not be null.");
"Input(X) should not be null."); PADDLE_ENFORCE(!ctx->Outputs(framework::GradVarName("X")).empty(),
PADDLE_ENFORCE(!ctx.MultiOutputVar(framework::GradVarName("X")).empty(),
"Output(X@Grad) should not be null."); "Output(X@Grad) should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null."); "Input(Out@GRAD) should not be null.");
auto d_ins = ctx.MultiOutput<Tensor>(framework::GradVarName("X")); std::vector<framework::DDim> d_ins;
auto ins = ctx.MultiInput<Tensor>("X"); auto ins = ctx->GetInputsDim("X");
// No need to compute gradient for Input(Ids) // No need to compute gradient for Input(Ids)
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (d_ins[i]) { d_ins.push_back(ins[i]);
d_ins[i]->Resize(ins[i]->dims());
}
} }
ctx->SetOutputsDim(framework::GradVarName("X"), d_ins);
} }
}; };
......
...@@ -24,14 +24,13 @@ class PadOp : public framework::OperatorWithKernel { ...@@ -24,14 +24,13 @@ class PadOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of PadOp should not be null.");
"Input(X) of PadOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"),
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), "Output(Out) of PadOp should not be null.");
"Output(Out) of PadOp should not be null.");
auto x_dim = ctx->GetInputDim("X");
auto x_dim = ctx.Input<Tensor>("X")->dims(); auto paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
auto paddings = Attr<std::vector<int>>("paddings");
PADDLE_ENFORCE_EQ(x_dim.size() * 2, int64_t(paddings.size()), PADDLE_ENFORCE_EQ(x_dim.size() * 2, int64_t(paddings.size()),
"Size of paddings should be equal to 2 * dimension size " "Size of paddings should be equal to 2 * dimension size "
"of input tensor."); "of input tensor.");
...@@ -39,19 +38,18 @@ class PadOp : public framework::OperatorWithKernel { ...@@ -39,19 +38,18 @@ class PadOp : public framework::OperatorWithKernel {
for (int i = 0; i < x_dim.size(); ++i) { for (int i = 0; i < x_dim.size(); ++i) {
out_dims[i] = x_dim[i] + paddings[i * 2] + paddings[i * 2 + 1]; out_dims[i] = x_dim[i] + paddings[i * 2] + paddings[i * 2 + 1];
} }
ctx.Output<framework::Tensor>("Out")->Resize( ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
framework::make_ddim(out_dims));
if (out_dims[0] == x_dim[0]) { if (out_dims[0] == x_dim[0]) {
// Only pass LoD when the first dimension is equal between // Only pass LoD when the first dimension is equal between
// output and input. // output and input.
ctx.ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
} }
}; };
class PadOpMaker : public framework::OpProtoAndCheckerMaker { class PadOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
PadOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) PadOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"The input of pad op. " "The input of pad op. "
...@@ -68,15 +66,15 @@ Given: ...@@ -68,15 +66,15 @@ Given:
X = [[1, 2], X = [[1, 2],
[3, 4]] [3, 4]]
and and
paddings = [0, 1, 1, 2] paddings = [0, 1, 1, 2]
and and
pad_value = 0
then we get pad_value = 0
then we get
Out = [[0, 1, 2, 0, 0] Out = [[0, 1, 2, 0, 0]
[0, 3, 4, 0, 0] [0, 3, 4, 0, 0]
...@@ -101,14 +99,14 @@ class PadOpGrad : public framework::OperatorWithKernel { ...@@ -101,14 +99,14 @@ class PadOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null"); "Input(Out@GRAD) should not be null");
auto x_dims = ctx.Input<Tensor>("X")->dims(); auto x_dims = ctx->GetInputDim("X");
auto *x_g = ctx.Output<framework::Tensor>(framework::GradVarName("X")); auto x_grad_name = framework::GradVarName("X");
if (x_g != nullptr) { if (ctx->HasOutput(x_grad_name)) {
x_g->Resize(x_dims); ctx->SetOutputDim(x_grad_name, x_dims);
} }
} }
}; };
......
...@@ -26,19 +26,14 @@ class PReluOp : public framework::OperatorWithKernel { ...@@ -26,19 +26,14 @@ class PReluOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
auto *in = ctx.Input<framework::Tensor>("X"); PADDLE_ENFORCE(ctx->HasInput("Alpha"), "Input(Alpha) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Alpha"), PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == 1,
"Input(Alpha) should not be null"); "Size of weight Alpha must be one.");
auto *alpha = ctx.Input<framework::Tensor>("Alpha"); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null");
PADDLE_ENFORCE(alpha->numel() == 1, "Size of weight Alpha must be one."); ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
"Output(Out) should not be null");
auto *out = ctx.Output<framework::Tensor>("Out");
out->Resize(in->dims());
ctx.ShareLoD("X", /*->*/ "Out");
} }
}; };
...@@ -68,19 +63,13 @@ class PReluGradOp : public framework::OperatorWithKernel { ...@@ -68,19 +63,13 @@ class PReluGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null"); "Input(Out@GRAD) should not be null");
auto *dx = ctx.Output<framework::Tensor>(framework::GradVarName("X")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
auto *x = ctx.Input<framework::Tensor>("X"); ctx->SetOutputDim(framework::GradVarName("Alpha"),
ctx->GetInputDim("Alpha"));
auto *dalpha =
ctx.Output<framework::Tensor>(framework::GradVarName("Alpha"));
auto *alpha = ctx.Input<framework::Tensor>("Alpha");
dx->Resize(x->dims());
dalpha->Resize(alpha->dims());
} }
}; };
......
...@@ -25,22 +25,21 @@ class RankLossOp : public framework::OperatorWithKernel { ...@@ -25,22 +25,21 @@ class RankLossOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase *ctx) const override {
// input check // input check
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) shouldn't be null");
"Input(Label) shouldn't be null"); PADDLE_ENFORCE(ctx->HasInput("Left"), "Input(Left) shouldn't be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Left"), PADDLE_ENFORCE(ctx->HasInput("Right"), "Input(Right) shouldn't be null");
"Input(Left) shouldn't be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Right"), auto label_dims = ctx->GetInputDim("Label");
"Input(Right) shouldn't be null"); auto left_dims = ctx->GetInputDim("Left");
auto label_dims = ctx.Input<framework::Tensor>("Label")->dims(); auto right_dims = ctx->GetInputDim("Right");
auto left_dims = ctx.Input<framework::Tensor>("Left")->dims();
auto right_dims = ctx.Input<framework::Tensor>("Right")->dims();
PADDLE_ENFORCE((label_dims == left_dims) && (left_dims == right_dims), PADDLE_ENFORCE((label_dims == left_dims) && (left_dims == right_dims),
"All inputs must have the same size"); "All inputs must have the same size");
PADDLE_ENFORCE((label_dims.size() == 2) && (label_dims[1] == 1), PADDLE_ENFORCE((label_dims.size() == 2) && (label_dims[1] == 1),
"All inputs must be row vector with size batch_size x 1."); "All inputs must be row vector with size batch_size x 1.");
ctx.Output<framework::Tensor>("Out")->Resize(label_dims); ctx->SetOutputDim("Out", label_dims);
} }
}; };
...@@ -91,25 +90,22 @@ class RankLossGradOp : public framework::OperatorWithKernel { ...@@ -91,25 +90,22 @@ class RankLossGradOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) shouldn't be null.");
"Input(Label) shouldn't be null."); PADDLE_ENFORCE(ctx->HasInput("Left"), "Input(Left) shouldn't be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Left"), PADDLE_ENFORCE(ctx->HasInput("Right"), "Input(Right) shouldn't be null.");
"Input(Left) shouldn't be null."); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Right"), "Input(Out@GRAD) shouldn't be null.");
"Input(Right) shouldn't be null."); auto dims = ctx->GetInputDim("Left");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), auto left_grad_name = framework::GradVarName("Left");
"Input(Out@GRAD) shouldn't be null."); auto right_grad_name = framework::GradVarName("Right");
auto dims = ctx.Input<framework::Tensor>("Left")->dims();
auto *left_grad = if (ctx->HasOutput(left_grad_name)) {
ctx.Output<framework::Tensor>(framework::GradVarName("Left")); ctx->SetOutputDim(left_grad_name, dims);
auto *right_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("Right"));
if (left_grad) {
left_grad->Resize(dims);
} }
if (right_grad) {
right_grad->Resize(dims); if (ctx->HasOutput(right_grad_name)) {
ctx->SetOutputDim(right_grad_name, dims);
} }
} }
}; };
......
...@@ -26,14 +26,14 @@ class ReshapeOp : public framework::OperatorWithKernel { ...@@ -26,14 +26,14 @@ class ReshapeOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase *ctx) const override {
// input check // input check
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ReshapeOp should not be null."); "Input(X) of ReshapeOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ReshapeOp should not be null."); "Output(Out) of ReshapeOp should not be null.");
auto shape = ctx.Attr<std::vector<int>>("shape"); auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty."); PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty.");
for (auto dim : shape) { for (auto dim : shape) {
PADDLE_ENFORCE(dim > 0, "Each dimension of shape must be positive."); PADDLE_ENFORCE(dim > 0, "Each dimension of shape must be positive.");
...@@ -41,8 +41,8 @@ class ReshapeOp : public framework::OperatorWithKernel { ...@@ -41,8 +41,8 @@ class ReshapeOp : public framework::OperatorWithKernel {
// capacity check // capacity check
int64_t capacity = int64_t capacity =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()); std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
auto *in = ctx.Input<framework::Tensor>("X"); auto x_dims = ctx->GetInputDim("X");
int64_t in_size = framework::product(in->dims()); int64_t in_size = framework::product(x_dims);
PADDLE_ENFORCE_EQ(capacity, in_size, PADDLE_ENFORCE_EQ(capacity, in_size,
"The size of Input(X) mismatches with Attr(shape)."); "The size of Input(X) mismatches with Attr(shape).");
// resize output // resize output
...@@ -50,11 +50,11 @@ class ReshapeOp : public framework::OperatorWithKernel { ...@@ -50,11 +50,11 @@ class ReshapeOp : public framework::OperatorWithKernel {
std::transform(shape.begin(), shape.end(), shape_int64.begin(), std::transform(shape.begin(), shape.end(), shape_int64.begin(),
[](int a) { return static_cast<int64_t>(a); }); [](int a) { return static_cast<int64_t>(a); });
auto out_dims = framework::make_ddim(shape_int64); auto out_dims = framework::make_ddim(shape_int64);
ctx.Output<framework::Tensor>("Out")->Resize(out_dims); ctx->SetOutputDim("Out", out_dims);
if (shape[0] == in->dims()[0]) { if (shape[0] == x_dims[0]) {
// Only pass LoD when the first dimension is equal between // Only pass LoD when the first dimension is equal between
// output and input. // output and input.
ctx.ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
} }
}; };
...@@ -76,7 +76,7 @@ Given a 2-D tensor X with 2 rows and 2 columns ...@@ -76,7 +76,7 @@ Given a 2-D tensor X with 2 rows and 2 columns
[[1, 2], [3, 4]] [[1, 2], [3, 4]]
with target shape = [1, 4], the reshape operator will transform with target shape = [1, 4], the reshape operator will transform
the tensor X into a 1-D tensor: the tensor X into a 1-D tensor:
[1, 2, 3, 4] [1, 2, 3, 4]
...@@ -94,13 +94,11 @@ class ReshapeGradOp : public framework::OperatorWithKernel { ...@@ -94,13 +94,11 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) shouldn't be null."); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) shouldn't be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) shouldn't be null."); "Input(Out@GRAD) shouldn't be null.");
auto dims = ctx.Input<framework::Tensor>("X")->dims(); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
auto *d_in = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
d_in->Resize(dims);
} }
}; };
......
...@@ -24,16 +24,16 @@ class RowwiseAddOp : public framework::OperatorWithKernel { ...@@ -24,16 +24,16 @@ class RowwiseAddOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of RowwiseAddOp should not be null."); "Input(X) of RowwiseAddOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("b"), PADDLE_ENFORCE(ctx->HasInput("b"),
"Input(b) of RowwiseAddOp should not be null."); "Input(b) of RowwiseAddOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of RowwiseAddOp should not be null."); "Output(Out) of RowwiseAddOp should not be null.");
auto x_dims = ctx.Input<Tensor>("X")->dims(); auto x_dims = ctx->GetInputDim("X");
auto b_dims = ctx.Input<Tensor>("b")->dims(); auto b_dims = ctx->GetInputDim("b");
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
x_dims.size(), b_dims.size(), x_dims.size(), b_dims.size(),
"The rank of input `X` must be larger than the one of input `b`."); "The rank of input `X` must be larger than the one of input `b`.");
...@@ -43,16 +43,17 @@ class RowwiseAddOp : public framework::OperatorWithKernel { ...@@ -43,16 +43,17 @@ class RowwiseAddOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims, framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims,
"The width of two operands must be same"); "The width of two operands must be same");
PADDLE_ENFORCE_EQ(ctx.OutputSize("Out"), 1, "The output size must be 1"); PADDLE_ENFORCE_EQ(ctx->Outputs("Out").size(), 1,
ctx.Output<framework::Tensor>("Out")->Resize(x_dims); "The output size must be 1");
ctx.ShareLoD("X", /*->*/ "Out"); ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
} }
}; };
class RowwiseAddOpMaker : public framework::OpProtoAndCheckerMaker { class RowwiseAddOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
RowwiseAddOpMaker(framework::OpProto *proto, RowwiseAddOpMaker(framework::OpProto* proto,
framework::OpAttrChecker *op_checker) framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The left input of row-wise add op, must be matrix"); AddInput("X", "The left input of row-wise add op, must be matrix");
AddInput("b", "The right input of row-wise add op, must be vector"); AddInput("b", "The right input of row-wise add op, must be vector");
...@@ -69,25 +70,29 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel { ...@@ -69,25 +70,29 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X should not be null"); PADDLE_ENFORCE(ctx->HasInput("X"), "X should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("b"), "b should not be null"); PADDLE_ENFORCE(ctx->HasInput("b"), "b should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null"); "Input(Out@GRAD) should not be null");
auto x_dims = ctx.Input<Tensor>("X")->dims(); auto x_dims = ctx->GetInputDim("X");
auto b_dims = ctx.Input<Tensor>("b")->dims(); auto b_dims = ctx->GetInputDim("b");
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
x_dims.size(), b_dims.size(), x_dims.size(), b_dims.size(),
"The rank of input `X` must be larger than the one of input `b`."); "The rank of input `X` must be larger than the one of input `b`.");
int num_col_dims = x_dims.size() - b_dims.size(); int64_t num_col_dims = x_dims.size() - b_dims.size();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims, framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims,
"The width of two operands must be same"); "The width of two operands must be same");
auto *dx = ctx.Output<framework::Tensor>(framework::GradVarName("X")); auto x_grad_name = framework::GradVarName("X");
auto *db = ctx.Output<framework::Tensor>(framework::GradVarName("b")); auto b_grad_name = framework::GradVarName("b");
if (dx) dx->Resize(x_dims); if (ctx->HasOutput(x_grad_name)) {
if (db) db->Resize(b_dims); ctx->SetOutputDim(x_grad_name, x_dims);
}
if (ctx->HasOutput(b_grad_name)) {
ctx->SetOutputDim(b_grad_name, b_dims);
}
} }
}; };
......
...@@ -26,16 +26,13 @@ class ScaleOp : public framework::OperatorWithKernel { ...@@ -26,16 +26,13 @@ class ScaleOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ScaleOp should not be null."); "Input(X) of ScaleOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ScaleOp should not be null."); "Output(Out) of ScaleOp should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
auto *in = ctx.Input<framework::Tensor>("X"); ctx->ShareLoD("X", /*->*/ "Out");
auto *out = ctx.Output<framework::Tensor>("Out");
out->Resize(in->dims());
ctx.ShareLoD("X", /*->*/ "Out");
} }
}; };
......
...@@ -23,29 +23,30 @@ class ScatterOp : public framework::OperatorWithKernel { ...@@ -23,29 +23,30 @@ class ScatterOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Ref"), PADDLE_ENFORCE(ctx->HasInput("Ref"),
"Input(Ref) of ScatterOp should not be null."); "Input(Ref) of ScatterOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Index"), PADDLE_ENFORCE(ctx->HasInput("Index"),
"Input(Index) of ScatterOp should not be null."); "Input(Index) of ScatterOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Updates"), PADDLE_ENFORCE(ctx->HasInput("Updates"),
"Input(Updates) of ScatterOp should not be null."); "Input(Updates) of ScatterOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ScatterOp should not be null."); "Output(Out) of ScatterOp should not be null.");
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Index")->dims().size(), 1, auto updates_dims = ctx->GetInputDim("Updates");
auto ref_dims = ctx->GetInputDim("Ref");
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Index").size(), 1,
"Update Index should be 1-D."); "Update Index should be 1-D.");
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Ref")->dims().size(), PADDLE_ENFORCE_EQ(ref_dims.size(), updates_dims.size(),
ctx.Input<Tensor>("Updates")->dims().size(),
"Reference and Updates should have the same shape size"); "Reference and Updates should have the same shape size");
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Updates")->dims()[0], PADDLE_ENFORCE_EQ(ctx->GetInputDim("Updates")[0],
ctx.Input<Tensor>("Index")->dims()[0], ctx->GetInputDim("Index")[0],
"Updates and Index should have same batch-size."); "Updates and Index should have same batch-size.");
framework::DDim data_dim(ctx.Input<Tensor>("Updates")->dims()); framework::DDim data_dim(updates_dims);
for (int i = 1; i < data_dim.size(); ++i) for (int i = 1; i < data_dim.size(); ++i) {
PADDLE_ENFORCE_EQ(data_dim[i], ctx.Input<Tensor>("Updates")->dims()[i]); PADDLE_ENFORCE_EQ(data_dim[i], updates_dims[i]);
ctx.Output<framework::Tensor>("Out")->Resize( }
ctx.Input<Tensor>("Ref")->dims()); ctx->SetOutputDim("Out", ref_dims);
} }
}; };
...@@ -54,22 +55,17 @@ class ScatterGradOp : public framework::OperatorWithKernel { ...@@ -54,22 +55,17 @@ class ScatterGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
auto *dUpdates = ctx->SetOutputDim(framework::GradVarName("Updates"),
ctx.Output<framework::Tensor>(framework::GradVarName("Updates")); ctx->GetInputDim("Updates"));
auto *Updates = ctx.Input<Tensor>("Updates"); ctx->SetOutputDim(framework::GradVarName("Ref"), ctx->GetInputDim("Ref"));
auto *dRef = ctx.Output<framework::Tensor>(framework::GradVarName("Ref"));
auto *Ref = ctx.Input<Tensor>("Ref");
dRef->Resize(Ref->dims());
dUpdates->Resize(Updates->dims());
} }
}; };
class ScatterOpMaker : public framework::OpProtoAndCheckerMaker { class ScatterOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ScatterOpMaker(framework::OpProto *proto, ScatterOpMaker(framework::OpProto* proto,
framework::OpAttrChecker *op_checker) framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Ref", "The source input of scatter op"); AddInput("Ref", "The source input of scatter op");
AddInput("Index", AddInput("Index",
...@@ -77,13 +73,14 @@ class ScatterOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -77,13 +73,14 @@ class ScatterOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Updates", "The updated value of updates op"); AddInput("Updates", "The updated value of updates op");
AddOutput("Out", "The output of add op"); AddOutput("Out", "The output of add op");
AddComment(R"DOC( AddComment(R"DOC(
Scatter Operator by selecting from the first axis, Scatter Operator by selecting from the first axis,
Out = Ref Out = Ref
Out[Index] = Ref[Index] + Updates Out[Index] = Ref[Index] + Updates
)DOC"); )DOC");
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
......
...@@ -22,23 +22,12 @@ class SequencePoolOp : public framework::OperatorWithKernel { ...@@ -22,23 +22,12 @@ class SequencePoolOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext& ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequencePoolOp should not be null."); "Input(X) of SequenceAvgPoolOp should not be null.");
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE(ctx->HasOutput("Out"),
ctx.OutputVar("Out"), "Output(Out) of SequenceAvgPoolOp should not be null.");
"Output(Out) of SequencePoolOp should not be null."); ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
auto* x = ctx.Input<framework::LoDTensor>("X");
auto dims = x->dims();
auto lod = x->lod();
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
PADDLE_ENFORCE_GE(
dims[0],
/*batch size = */ static_cast<int64_t>(lod[0].size() - 1),
"The first dimension of Input(X) must be large than batch size.");
dims[0] = lod[0].size() - 1;
ctx.Output<framework::LoDTensor>("Out")->Resize({dims});
} }
}; };
...@@ -61,17 +50,17 @@ class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -61,17 +50,17 @@ class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker {
SequencePoolOp pools features of all time-steps of each instance. SequencePoolOp pools features of all time-steps of each instance.
For a mini-batch of 3 variable lengths sentences, containing 2, 3, and 2 time-steps: For a mini-batch of 3 variable lengths sentences, containing 2, 3, and 2 time-steps:
Assume X is a [7,M,N] float LoDTensor, and X->lod()[0] = [0, 2, 5, 7]. Assume X is a [7,M,N] float LoDTensor, and X->lod()[0] = [0, 2, 5, 7].
Besides, for the sake of simplicity, we assume M=1 and N=1, Besides, for the sake of simplicity, we assume M=1 and N=1,
and the value of X = [[1, 3], [2, 4, 6], [5, 1]]. and the value of X = [[1, 3], [2, 4, 6], [5, 1]].
Thus, Out is a [3,1,1] float LoDTensor, but Out->lod() is nullptr. Thus, Out is a [3,1,1] float LoDTensor, but Out->lod() is nullptr.
And for different strategy, the value of Out is as follows: And for different strategy, the value of Out is as follows:
- AVERAGE: [2, 4, 3], where 2=(1+3)/2, 4=(2+4+6)/3, 3=(5+1)/2 - AVERAGE: [2, 4, 3], where 2=(1+3)/2, 4=(2+4+6)/3, 3=(5+1)/2
- SUM: [4, 12, 6], where 4=1+3, 12=2+4+6, 6=5+1 - SUM: [4, 12, 6], where 4=1+3, 12=2+4+6, 6=5+1
- SQRT: [2.82, 6.93, 4.24], where 2.82=(1+3)/sqrt(2), - SQRT: [2.82, 6.93, 4.24], where 2.82=(1+3)/sqrt(2),
6.93=(2+4+6)/sqrt(3), 4.24=(5+1)/sqrt(2) 6.93=(2+4+6)/sqrt(3), 4.24=(5+1)/sqrt(2)
- MAX: [3, 6, 5], where 3=max(1,3), 6=max(2,4,6), 5=max(5,1) - MAX: [3, 6, 5], where 3=max(1,3), 6=max(2,4,6), 5=max(5,1)
- LAST: [3, 6, 1], where 3=last(1,3), 6=last(2,4,6), 1=last(5,1) - LAST: [3, 6, 1], where 3=last(1,3), 6=last(2,4,6), 1=last(5,1)
...@@ -85,22 +74,18 @@ class SequencePoolGradOp : public framework::OperatorWithKernel { ...@@ -85,22 +74,18 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext& ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Gradient of Out should not be null."); "Gradient of Out should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), PADDLE_ENFORCE(ctx->HasInput("X"), "The input X should not be null.");
"The input X should not be null."); auto og_dims = ctx->GetInputDim(framework::GradVarName("Out"));
auto og_dims = auto x_dims = ctx->GetInputDim("X");
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->dims();
auto x_dims = ctx.Input<framework::LoDTensor>("X")->dims();
PADDLE_ENFORCE_EQ(og_dims.size(), x_dims.size(), PADDLE_ENFORCE_EQ(og_dims.size(), x_dims.size(),
"The rank of output grad must equal to Input(X)."); "The rank of output grad must equal to Input(X).");
for (int64_t i = 1; i < og_dims.size(); ++i) { for (int64_t i = 1; i < og_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(og_dims[i], x_dims[i], "The dimension mismatch."); PADDLE_ENFORCE_EQ(og_dims[i], x_dims[i], "The dimension mismatch.");
} }
auto* x_grad = ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
x_grad->Resize(x_dims);
} }
}; };
......
...@@ -46,16 +46,27 @@ class SequencePoolKernel : public framework::OpKernel { ...@@ -46,16 +46,27 @@ class SequencePoolKernel : public framework::OpKernel {
int strategy = context.Attr<int>("strategy"); int strategy = context.Attr<int>("strategy");
auto dims = in->dims(); auto dims = in->dims();
auto lod = in->lod()[0]; auto lod = in->lod();
int64_t w = in->numel() / dims[0]; int64_t w = in->numel() / dims[0];
// InferShape by lod
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
PADDLE_ENFORCE_GE(
dims[0],
/*batch size = */ static_cast<int64_t>(lod[0].size() - 1),
"The first dimension of Input(X) must be large than batch size.");
dims[0] = lod[0].size() - 1;
out->Resize({dims});
auto lod_level_0 = lod[0];
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
auto place = context.GetEigenDevice<Place>(); auto place = context.GetEigenDevice<Place>();
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) { for (int i = 0; i < static_cast<int>(lod_level_0.size()) - 1; ++i) {
Tensor in_t = Tensor in_t = in->Slice<T>(static_cast<int>(lod_level_0[i]),
in->Slice<T>(static_cast<int>(lod[i]), static_cast<int>(lod[i + 1])); static_cast<int>(lod_level_0[i + 1]));
Tensor out_t = out->Slice<T>(i, i + 1); Tensor out_t = out->Slice<T>(i, i + 1);
int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]); int64_t h = static_cast<int64_t>(lod_level_0[i + 1] - lod_level_0[i]);
auto in_e = EigenMatrix<T>::From(in_t, framework::make_ddim({h, w})); auto in_e = EigenMatrix<T>::From(in_t, framework::make_ddim({h, w}));
auto out_e = EigenVector<T>::Flatten(out_t); auto out_e = EigenVector<T>::Flatten(out_t);
......
...@@ -22,19 +22,18 @@ class SGDOp : public framework::OperatorWithKernel { ...@@ -22,19 +22,18 @@ class SGDOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("param"), PADDLE_ENFORCE(ctx->HasInput("param"),
"Input(param) of SGDOp should not be null."); "Input(param) of SGDOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("grad"), PADDLE_ENFORCE(ctx->HasInput("grad"),
"Input(grad) of SGDOp should not be null."); "Input(grad) of SGDOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("param_out"), PADDLE_ENFORCE(ctx->HasOutput("param_out"),
"Output(param_out) of SGDOp should not be null."); "Output(param_out) of SGDOp should not be null.");
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("param")->dims(), auto param_dim = ctx->GetInputDim("param");
ctx.Input<Tensor>("grad")->dims(), PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("grad"),
"Two input of SGD Op's dimension must be same."); "Two input of SGD Op's dimension must be same.");
ctx.Output<framework::Tensor>("param_out") ctx->SetOutputDim("param_out", param_dim);
->Resize(ctx.Input<Tensor>("param")->dims());
} }
}; };
......
...@@ -22,33 +22,28 @@ class SmoothL1LossOp : public framework::OperatorWithKernel { ...@@ -22,33 +22,28 @@ class SmoothL1LossOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext& ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X must be initialized."); PADDLE_ENFORCE(ctx->HasInput("X"), "X must be initialized.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Y must be initialized."); PADDLE_ENFORCE(ctx->HasInput("Y"), "Y must be initialized.");
auto* x = ctx.Input<framework::Tensor>("X"); auto x_dims = ctx->GetInputDim("X");
auto* y = ctx.Input<framework::Tensor>("Y"); auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(x->dims(), y->dims(), PADDLE_ENFORCE_EQ(x_dims, y_dims, "The shape of X and Y must be the same.");
"The shape of X and Y must be the same."); PADDLE_ENFORCE_GE(x_dims.size(), 2,
PADDLE_ENFORCE_GE(x->dims().size(), 2,
"The tensor rank of X must be at least 2."); "The tensor rank of X must be at least 2.");
auto* inside_weight = ctx.Input<framework::Tensor>("InsideWeight"); if (ctx->HasInput("InsideWeight")) {
if (inside_weight) { PADDLE_ENFORCE(ctx->HasInput("OutsideWeight"),
auto* outside_weight = ctx.Input<framework::Tensor>("OutsideWeight"); "If weights are provided, must specify both "
PADDLE_ENFORCE_NOT_NULL(outside_weight, "inside and outside weights.");
"If weights are provided, must specify both " PADDLE_ENFORCE_EQ(ctx->GetInputDim("InsideWeight"), x_dims,
"inside and outside weights.");
PADDLE_ENFORCE_EQ(inside_weight->dims(), x->dims(),
"The shape of InsideWeight must be same as X."); "The shape of InsideWeight must be same as X.");
PADDLE_ENFORCE_EQ(outside_weight->dims(), x->dims(), PADDLE_ENFORCE_EQ(ctx->GetInputDim("OutsideWeight"), x_dims,
"The shape of OutsideWeight must be same as X."); "The shape of OutsideWeight must be same as X.");
} }
auto* diff = ctx.Output<framework::Tensor>("Diff"); ctx->SetOutputDim("Diff", x_dims);
auto* out = ctx.Output<framework::Tensor>("Out");
diff->Resize(x->dims());
// loss is a two-rank tensor // loss is a two-rank tensor
out->Resize({x->dims()[0], 1}); ctx->SetOutputDim("Out", {x_dims[0], 1});
} }
}; };
...@@ -99,12 +94,9 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel { ...@@ -99,12 +94,9 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext& ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
auto in_dims = ctx.Input<framework::Tensor>("X")->dims(); auto in_dims = ctx->GetInputDim("X");
auto out_dims = auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->dims();
auto* x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* y_grad = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
PADDLE_ENFORCE_GE(out_dims.size(), 2, PADDLE_ENFORCE_GE(out_dims.size(), 2,
"The tensor rank of Input(Out@Grad) should be 2."); "The tensor rank of Input(Out@Grad) should be 2.");
...@@ -114,8 +106,14 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel { ...@@ -114,8 +106,14 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(out_dims[1], 1, PADDLE_ENFORCE_EQ(out_dims[1], 1,
"The 2nd dimension of Input(Out@Grad) must be 1."); "The 2nd dimension of Input(Out@Grad) must be 1.");
if (x_grad) x_grad->Resize(in_dims); auto x_grad_name = framework::GradVarName("X");
if (y_grad) y_grad->Resize(in_dims); auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, in_dims);
}
if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, in_dims);
}
} }
}; };
......
...@@ -22,22 +22,23 @@ class SoftmaxOp : public framework::OperatorWithKernel { ...@@ -22,22 +22,23 @@ class SoftmaxOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SoftmaxOp should not be null."); "Input(X) of SoftmaxOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"), PADDLE_ENFORCE(ctx->HasOutput("Y"),
"Output(Y) of SoftmaxOp should not be null."); "Output(Y) of SoftmaxOp should not be null.");
PADDLE_ENFORCE(ctx.Input<Tensor>("X")->dims().size() == 2UL, auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE(x_dims.size() == 2UL,
"The input of softmax op must be a matrix."); "The input of softmax op must be a matrix.");
ctx.Output<framework::Tensor>("Y")->Resize(ctx.Input<Tensor>("X")->dims()); ctx->SetOutputDim("Y", x_dims);
} }
}; };
class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SoftmaxOpMaker(framework::OpProto *proto, SoftmaxOpMaker(framework::OpProto* proto,
framework::OpAttrChecker *op_checker) framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"The input tensor of softmax. " "The input tensor of softmax. "
...@@ -68,16 +69,15 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { ...@@ -68,16 +69,15 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should be not null."); PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
"Input(Y@GRAD) should be not null."); "Input(Y@GRAD) should be not null.");
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Y")->dims(), PADDLE_ENFORCE_EQ(ctx->GetInputDim("Y"),
ctx.Input<Tensor>(framework::GradVarName("Y"))->dims(), ctx->GetInputDim(framework::GradVarName("Y")),
"Input(Y) and its gradients should have a same shape."); "Input(Y) and its gradients should have a same shape.");
ctx.Output<framework::Tensor>(framework::GradVarName("X")) ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
->Resize(ctx.Input<Tensor>("X")->dims());
} }
}; };
......
...@@ -24,40 +24,42 @@ class SplitOp : public framework::OperatorWithKernel { ...@@ -24,40 +24,42 @@ class SplitOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase *ctx) const override {
// infershape auto in_dims = ctx->GetInputDim("X");
auto *in = ctx.Input<framework::Tensor>("X"); auto outs_names = ctx->Outputs("Out");
auto outs = ctx.MultiOutput<framework::Tensor>("Out"); size_t axis = static_cast<size_t>(ctx->Attrs().Get<int>("axis"));
size_t axis = static_cast<size_t>(ctx.Attr<int>("axis")); size_t num = static_cast<size_t>(ctx->Attrs().Get<int>("num"));
size_t num = static_cast<size_t>(ctx.Attr<int>("num")); std::vector<int> sections = static_cast<std::vector<int>>(
std::vector<int> sections = ctx->Attrs().Get<std::vector<int>>("sections"));
static_cast<std::vector<int>>(ctx.Attr<std::vector<int>>("sections")); const size_t outs_number = outs_names.size();
const size_t n = outs.size(); std::vector<framework::DDim> outs_dims;
outs_dims.reserve(outs_number);
if (num > 0) { if (num > 0) {
int64_t in_axis_dim = in->dims()[axis]; int64_t in_axis_dim = in_dims[axis];
PADDLE_ENFORCE_EQ(in_axis_dim % num, 0, PADDLE_ENFORCE_EQ(in_axis_dim % num, 0,
"tensor split does not result" "tensor split does not result"
" in an equal division"); " in an equal division");
size_t out_axis_dim = in_axis_dim / num; size_t out_axis_dim = in_axis_dim / num;
for (size_t i = 0; i < n; ++i) { for (size_t i = 0; i < outs_number; ++i) {
auto dim = in->dims(); auto dim = in_dims;
dim[axis] = out_axis_dim; dim[axis] = out_axis_dim;
outs[i]->Resize(dim); outs_dims.push_back(dim);
} }
} else if (sections.size() > 0) { } else if (sections.size() > 0) {
PADDLE_ENFORCE_EQ(sections.size(), n, PADDLE_ENFORCE_EQ(sections.size(), outs_number,
"tensor split sections size" "tensor split sections size"
"should be equal to output size."); "should be equal to output size.");
for (size_t i = 0; i < n; ++i) { for (size_t i = 0; i < outs_number; ++i) {
auto dim = in->dims(); auto dim = in_dims;
dim[axis] = sections[i]; dim[axis] = sections[i];
outs[i]->Resize(dim); outs_dims.push_back(dim);
} }
} else { } else {
PADDLE_ENFORCE_NOT_NULL(nullptr, "split operator should", PADDLE_ENFORCE_NOT_NULL(nullptr, "split operator should",
" specify indices or sections."); " specify indices or sections.");
} }
ctx->SetOutputsDim("Out", outs_dims);
} }
}; };
......
...@@ -22,24 +22,19 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel { ...@@ -22,24 +22,19 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext& ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE(ctx->HasInput("X"),
ctx.InputVar("X"), "Input(X) of SquaredL2DistanceOp should not be null.");
"Input(X) of SquaredL2DistanceOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Y"),
PADDLE_ENFORCE_NOT_NULL( "Input(Y) of SquaredL2DistanceOp should not be null.");
ctx.InputVar("Y"), PADDLE_ENFORCE(
"Input(Y) of SquaredL2DistanceOp should not be null."); ctx->HasOutput("sub_result"),
PADDLE_ENFORCE_NOT_NULL(
ctx.OutputVar("sub_result"),
"Output(sub_result) of SquaredL2DistanceOp should not be null."); "Output(sub_result) of SquaredL2DistanceOp should not be null.");
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE(ctx->HasOutput("Out"),
ctx.OutputVar("Out"), "Output(Out) of SquaredL2DistanceOp should not be null.");
"Output(Out) of SquaredL2DistanceOp should not be null.");
auto* x = ctx.Input<Tensor>("X"); auto x_dims = ctx->GetInputDim("X");
auto x_dims = x->dims(); auto y_dims = ctx->GetInputDim("Y");
auto* y = ctx.Input<Tensor>("Y");
auto y_dims = y->dims();
PADDLE_ENFORCE_EQ(framework::arity(x_dims), framework::arity(y_dims), PADDLE_ENFORCE_EQ(framework::arity(x_dims), framework::arity(y_dims),
"Tensor rank of both SquaredL2DistanceOp's " "Tensor rank of both SquaredL2DistanceOp's "
...@@ -47,17 +42,16 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel { ...@@ -47,17 +42,16 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
int rank = framework::arity(x_dims); int rank = framework::arity(x_dims);
PADDLE_ENFORCE_GE(rank, 2, "Tensor rank should be at least equal to 2."); PADDLE_ENFORCE_GE(rank, 2, "Tensor rank should be at least equal to 2.");
PADDLE_ENFORCE_EQ(x->numel() / x_dims[0], y->numel() / y_dims[0], PADDLE_ENFORCE_EQ(product(x_dims) / x_dims[0], product(y_dims) / y_dims[0],
"Product of dimensions expcet the first dimension of " "Product of dimensions expcet the first dimension of "
"input and target must be equal."); "input and target must be equal.");
PADDLE_ENFORCE(y_dims[0] == 1 || y_dims[0] == x_dims[0], PADDLE_ENFORCE(y_dims[0] == 1 || y_dims[0] == x_dims[0],
"First dimension of target must be equal to input " "First dimension of target must be equal to input "
"or to 1."); "or to 1.");
ctx.Output<framework::Tensor>("sub_result") ctx->SetOutputDim("sub_result", {x_dims[0], product(x_dims) / x_dims[0]});
->Resize({x_dims[0], x->numel() / x_dims[0]}); ctx->SetOutputDim("Out", {x_dims[0], 1});
ctx.Output<framework::Tensor>("Out")->Resize({x_dims[0], 1}); ctx->ShareLoD("X", /*->*/ "Out");
ctx.ShareLoD("X", /*->*/ "Out");
} }
}; };
...@@ -92,22 +86,22 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel { ...@@ -92,22 +86,22 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext& ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Gradient of Out should not be null"); "Gradient of Out should not be null");
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims(); auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
auto x_dims = ctx.Input<Tensor>("X")->dims(); auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx.Input<Tensor>("Y")->dims(); auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(out_dims[0], x_dims[0], PADDLE_ENFORCE_EQ(out_dims[0], x_dims[0],
"First dimension of output gradient and " "First dimension of output gradient and "
"input value must be equal."); "input value must be equal.");
PADDLE_ENFORCE_EQ(out_dims[1], 1, PADDLE_ENFORCE_EQ(out_dims[1], 1,
"Second dimension of output gradient " "Second dimension of output gradient "
"must be 1."); "must be 1.");
auto* x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X")); auto x_grad_name = framework::GradVarName("X");
auto* y_grad = ctx.Output<framework::Tensor>(framework::GradVarName("Y")); auto y_grad_name = framework::GradVarName("Y");
if (x_grad) x_grad->Resize(x_dims); if (ctx->HasOutput(x_grad_name)) ctx->SetOutputDim(x_grad_name, x_dims);
if (y_grad) y_grad->Resize(y_dims); if (ctx->HasOutput(y_grad_name)) ctx->SetOutputDim(y_grad_name, y_dims);
} }
}; };
......
...@@ -21,31 +21,27 @@ class SumOp : public framework::OperatorWithKernel { ...@@ -21,31 +21,27 @@ class SumOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(), auto x_dims = ctx->GetInputsDim("X");
"Input(X) of SumOp should not be null."); PADDLE_ENFORCE(!x_dims.empty(), "Input(X) of SumOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SumOp should not be null."); "Output(Out) of SumOp should not be null.");
auto ins = ctx.MultiInput<framework::Tensor>("X");
auto *out = ctx.Output<framework::Tensor>("Out");
int N = ins.size();
auto in_dim = ins[0]->dims();
auto in_dim = x_dims[0];
size_t N = x_dims.size();
PADDLE_ENFORCE_GT(N, 1, "Input tensors count should > 1."); PADDLE_ENFORCE_GT(N, 1, "Input tensors count should > 1.");
for (int i = 1; i < N; i++) { for (size_t i = 1; i < N; i++) {
auto dim = ins[i]->dims(); auto dim = x_dims[i];
PADDLE_ENFORCE(in_dim == dim, "Input tensors must have same shape"); PADDLE_ENFORCE(in_dim == dim, "Input tensors must have same shape");
} }
out->Resize(in_dim); ctx->SetOutputDim("Out", in_dim);
ctx.ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
}; };
class SumOpMaker : public framework::OpProtoAndCheckerMaker { class SumOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) SumOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "the input tensors of sum operator.").AsDuplicable(); AddInput("X", "the input tensors of sum operator.").AsDuplicable();
AddOutput("Out", "the output tensor of sum operator."); AddOutput("Out", "the output tensor of sum operator.");
...@@ -63,13 +59,16 @@ class SumGradOp : public framework::OperatorWithKernel { ...@@ -63,13 +59,16 @@ class SumGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
auto outputs = auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));
ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X")); auto x_grad_names = ctx->Outputs(framework::GradVarName("X"));
auto dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims(); size_t x_length = x_grad_names.size();
for (auto output : outputs) { std::vector<framework::DDim> x_grad_dims;
output->Resize(dims); x_grad_dims.reserve(x_length);
for (size_t i = 0; i < x_length; ++i) {
x_grad_dims.push_back(out_grad_dims);
} }
ctx->SetOutputsDim(framework::GradVarName("X"), x_grad_dims);
} }
}; };
......
...@@ -22,26 +22,26 @@ class TopkOp : public framework::OperatorWithKernel { ...@@ -22,26 +22,26 @@ class TopkOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of TopkOp should not be null."); "Input(X) of TopkOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of TopkOp should not be null."); "Output(Out) of TopkOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Indices"), PADDLE_ENFORCE(ctx->HasOutput("Indices"),
"Output(Indices) of TopkOp should not be null."); "Output(Indices) of TopkOp should not be null.");
auto *input = ctx.Input<framework::Tensor>("X"); auto input_dims = ctx->GetInputDim("X");
const int k = static_cast<int>(ctx.Attr<int>("k")); const int k = static_cast<int>(ctx->Attrs().Get<int>("k"));
PADDLE_ENFORCE_GE(k, 1, "k must >= 1"); PADDLE_ENFORCE_GE(k, 1, "k must >= 1");
PADDLE_ENFORCE_GE(input->dims().size(), 1, "input must have >= 1d shape"); PADDLE_ENFORCE_GE(input_dims.size(), 1, "input must have >= 1d shape");
PADDLE_ENFORCE_GE(input->dims()[input->dims().size() - 1], k, PADDLE_ENFORCE_GE(input_dims[input_dims.size() - 1], k,
"input must have >= k columns"); "input must have >= k columns");
framework::DDim dims = input->dims(); framework::DDim dims = input_dims;
dims[dims.size() - 1] = k; dims[dims.size() - 1] = k;
ctx.Output<framework::Tensor>("Out")->Resize(dims); ctx->SetOutputDim("Out", dims);
ctx.Output<framework::Tensor>("Indices")->Resize(dims); ctx->SetOutputDim("Indices", dims);
} }
}; };
......
...@@ -24,12 +24,11 @@ class TransposeOp : public framework::OperatorWithKernel { ...@@ -24,12 +24,11 @@ class TransposeOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null");
"Output(Out) should not be null"); auto x_dims = ctx->GetInputDim("X");
auto x_dims = ctx.Input<Tensor>("X")->dims(); std::vector<int> axis = ctx->Attrs().Get<std::vector<int>>("axis");
std::vector<int> axis = ctx.Attr<std::vector<int>>("axis");
size_t x_rank = x_dims.size(); size_t x_rank = x_dims.size();
size_t axis_size = axis.size(); size_t axis_size = axis.size();
...@@ -51,14 +50,14 @@ class TransposeOp : public framework::OperatorWithKernel { ...@@ -51,14 +50,14 @@ class TransposeOp : public framework::OperatorWithKernel {
for (size_t i = 0; i < axis_size; i++) { for (size_t i = 0; i < axis_size; i++) {
out_dims[i] = x_dims[axis[i]]; out_dims[i] = x_dims[axis[i]];
} }
ctx.Output<framework::Tensor>("Out")->Resize(out_dims); ctx->SetOutputDim("Out", out_dims);
} }
}; };
class TransposeOpMaker : public framework::OpProtoAndCheckerMaker { class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
TransposeOpMaker(framework::OpProto *proto, TransposeOpMaker(framework::OpProto* proto,
framework::OpAttrChecker *op_checker) framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"X", "X",
...@@ -79,7 +78,7 @@ For example: ...@@ -79,7 +78,7 @@ For example:
[3, 4, 5]]) [3, 4, 5]])
>> axis = [1, 0] >> axis = [1, 0]
>> output = input.transpose(axis) >> output = input.transpose(axis)
>> output >> output
array([[0, 3], array([[0, 3],
[1, 4], [1, 4],
[2, 5]]) [2, 5]])
...@@ -94,14 +93,15 @@ class TransposeOpGrad : public framework::OperatorWithKernel { ...@@ -94,14 +93,15 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null"); "Input(Out@GRAD) should not be null");
auto x_dims = ctx.Input<Tensor>("X")->dims(); auto x_dims = ctx->GetInputDim("X");
auto *x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X")); ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
if (ctx->HasOutput(framework::GradVarName("X"))) {
if (x_grad) x_grad->Resize(x_dims); ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
} }
}; };
......
...@@ -23,18 +23,18 @@ namespace operators { ...@@ -23,18 +23,18 @@ namespace operators {
template <typename T> template <typename T>
class CPUUniformRandomKernel : public framework::OpKernel { class CPUUniformRandomKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* tensor = context.Output<framework::Tensor>("Out"); auto* tensor = ctx.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace()); T* data = tensor->mutable_data<T>(ctx.GetPlace());
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed")); unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
std::minstd_rand engine; std::minstd_rand engine;
if (seed == 0) { if (seed == 0) {
seed = std::random_device()(); seed = std::random_device()();
} }
engine.seed(seed); engine.seed(seed);
std::uniform_real_distribution<T> dist( std::uniform_real_distribution<T> dist(
static_cast<T>(context.Attr<float>("min")), static_cast<T>(ctx.Attr<float>("min")),
static_cast<T>(context.Attr<float>("max"))); static_cast<T>(ctx.Attr<float>("max")));
int64_t size = tensor->numel(); int64_t size = tensor->numel();
for (int64_t i = 0; i < size; ++i) { for (int64_t i = 0; i < size; ++i) {
data[i] = dist(engine); data[i] = dist(engine);
...@@ -47,21 +47,20 @@ class UniformRandomOp : public framework::OperatorWithKernel { ...@@ -47,21 +47,20 @@ class UniformRandomOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext& ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE(ctx->HasOutput("Out"),
ctx.OutputVar("Out"), "Output(Out) of UniformRandomOp should not be null.");
"Output(Out) of UniformRandomOp should not be null.");
PADDLE_ENFORCE(Attr<float>("min") < Attr<float>("max"), PADDLE_ENFORCE(
"uniform_random's min must less then max"); ctx->Attrs().Get<float>("min") < ctx->Attrs().Get<float>("max"),
auto* tensor = ctx.Output<framework::Tensor>("Out"); "uniform_random's min must less then max");
auto dims = Attr<std::vector<int>>("dims"); auto dims = Attr<std::vector<int>>("dims");
std::vector<int64_t> temp; std::vector<int64_t> temp;
temp.reserve(dims.size()); temp.reserve(dims.size());
for (auto dim : dims) { for (auto dim : dims) {
temp.push_back(static_cast<int64_t>(dim)); temp.push_back(static_cast<int64_t>(dim));
} }
tensor->Resize(framework::make_ddim(temp)); ctx->SetOutputDim("Out", framework::make_ddim(temp));
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册