提交 0d9ba3da 编写于 作者: Y yangyaming

Adapt to new interface.

上级 7be390aa
......@@ -24,26 +24,28 @@ class ExpandOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X must be initialized.");
std::vector<int> expand_times = Attr<std::vector<int>>("expandTimes");
auto x_dims = ctx.Input<Tensor>("X")->dims();
PADDLE_ENFORCE_EQ(x_dims.size(), expand_times.size(),
"The number of expandTimes's value must be equal "
"to the rank of X.");
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must be initialized.");
std::vector<int> expand_times =
ctx->Attrs().Get<std::vector<int>>("expandTimes");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dims.size()), expand_times.size(),
"The number of Attr(expandTimes)'s value must be equal "
"to the rank of Input(X).");
PADDLE_ENFORCE_LE(x_dims.size(), 6,
"The rank of X must not be greater than 6.");
"The rank of Input(X) must not be greater than 6.");
std::vector<int64_t> out_shape(x_dims.size());
for (size_t i = 0; i < expand_times.size(); ++i) {
PADDLE_ENFORCE_GE(expand_times[i], 1,
"Each value of expandTimes should not be "
"Each value of Attr(expandTimes) should not be "
"less than 1.");
out_shape[i] = x_dims[i] * expand_times[i];
}
auto* out = ctx.Output<framework::LoDTensor>("Out");
out->Resize(framework::make_ddim(out_shape));
ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
ctx->ShareLoD("X", "Out");
}
};
......@@ -52,20 +54,21 @@ class ExpandOpMaker : public framework::OpProtoAndCheckerMaker {
ExpandOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"The input tensor of expand op."
"The rank of X should be between in 1 and 6.");
"(Tensor, default Tensor<float>) A tensor with rank in [1, 6]."
"X is the input tensor to be expanded.");
AddOutput("Out",
"Output tensor of expand op."
"The rank of Out is same as X except that each dimension size "
"of Out equals to corresponding dimension size of X multiplying "
"corresponding value of expandTimes.");
"(Tensor, default Tensor<float>) A tensor with rank in [1, 6]."
"The rank of Output(Out) is same as Input(X) except that each "
"dimension size of Output(Out) is equal to corresponding "
"dimension size of Input(X) multiplying corresponding value of "
"Attr(expandTimes).");
AddAttr<std::vector<int>>("expandTimes",
"Expand times number for each dimension.");
AddComment(R"DOC(
Expand operator tiles the input by given times number. You should set times
number for each dimension by providing attribute 'expandTimes'. The rank of X
should be between in 1 and 6. Please notice that size of 'expandTimes' must be
same with X's rank.
should be in [1, 6]. Please notice that size of 'expandTimes' must be same with
X's rank.
)DOC");
}
};
......@@ -75,25 +78,27 @@ class ExpandGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X must be initialized.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null.");
auto x_dims = ctx.Input<Tensor>("X")->dims();
std::vector<int> expand_times = Attr<std::vector<int>>("expandTimes");
auto out_dims =
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->dims();
auto* x_grad =
ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null.");
auto x_dims = ctx->GetInputDim("X");
std::vector<int> expand_times =
ctx->Attrs().Get<std::vector<int>>("expandTimes");
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
for (size_t i = 0; i < expand_times.size(); ++i) {
PADDLE_ENFORCE_EQ(x_dims[i] * expand_times[i], out_dims[i],
"Each dimension size of Input(Out@GRAD) should be "
"equal to multiplication of crroresponding dimension "
"size of Input(X) and expandTimes value.");
"size of Input(X) and Attr(expandTimes) value.");
}
if (x_grad) x_grad->Resize(x_dims);
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
};
......
......@@ -45,6 +45,8 @@
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
......@@ -53,24 +55,24 @@ template <typename T, size_t D, int MajorType = Eigen::RowMajor,
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
template <typename Place, typename T>
class ExpandKernel : public framework::OpKernel {
class ExpandKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto rank = context.Input<framework::Tensor>("X")->dims().size();
auto rank = context.Input<Tensor>("X")->dims().size();
switch (rank) {
REP_EXPAND_TEMPLATE(6)
default:
PADDLE_ENFORCE(false,
"Only support tensor with rank being between 1 and 6.");
};
}
}
protected:
template <int Rank>
void Expand(const framework::ExecutionContext& context) const {
auto* in0 = context.Input<framework::Tensor>("X");
auto* in0 = context.Input<Tensor>("X");
auto& expand_times = context.Attr<std::vector<int>>("expandTimes");
auto* out0 = context.Output<framework::LoDTensor>("Out");
auto* out0 = context.Output<Tensor>("Out");
Eigen::DSizes<int, Rank> bcast_dims;
auto x_dims = in0->dims();
for (size_t i = 0; i < expand_times.size(); ++i) {
......@@ -85,10 +87,10 @@ class ExpandKernel : public framework::OpKernel {
};
template <typename Place, typename T>
class ExpandGradKernel : public framework::OpKernel {
class ExpandGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input<framework::Tensor>("X");
auto* in0 = context.Input<Tensor>("X");
auto& expand_times = context.Attr<std::vector<int>>("expandTimes");
auto x_dims = in0->dims();
std::vector<int> reshape_dims_vec;
......@@ -111,23 +113,17 @@ class ExpandGradKernel : public framework::OpKernel {
int dims = reshape_dims_vec.size() * 6 + reduce_dims_vec.size() - 7;
// no need reduce, just copy
if (reduce_dims_vec.size() == 0) {
auto* in0 =
context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto* out0 =
context.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
out0->mutable_data<T>(context.GetPlace());
if (platform::is_cpu_place(context.GetPlace())) {
out0->CopyFrom<T>(*in0, platform::CPUPlace());
} else {
out0->CopyFrom<T>(*in0, platform::GPUPlace());
}
out0->CopyFrom(*in0, context.GetPlace(), context.device_context());
} else {
switch (dims) {
REP_EXPAND_GRAD_TEMPLATE(72)
default:
PADDLE_ENFORCE(
false, "Only support tensor with rank being between 1 and 6.");
};
}
}
}
......@@ -144,11 +140,9 @@ class ExpandGradKernel : public framework::OpKernel {
PADDLE_ENFORCE_EQ(reduce_size, reduce_dims_vec.size(),
"Inconsistent size between template Dims and "
"reduce dimensions.");
auto* in0 =
context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto* out0 =
context.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto x = EigenVector<T>::Flatten(*(context.Input<framework::Tensor>("X")));
auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
auto x = EigenVector<T>::Flatten(*(context.Input<Tensor>("X")));
out0->mutable_data<T>(context.GetPlace());
auto x_grad = EigenVector<T>::Flatten(*out0);
Eigen::DSizes<int, Dims / 6 + 1> reshape_dims;
......@@ -165,5 +159,5 @@ class ExpandGradKernel : public framework::OpKernel {
}
};
} // operators
} // paddle
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册