提交 f31217fc 编写于 作者: W wanghaoshuang

Fix issues

上级 ba8a5c15
......@@ -27,10 +27,10 @@ class PadOp : public framework::OperatorWithKernel {
void InferShape(const framework::InferShapeContext &ctx) const override {
auto x_dim = ctx.Input<Tensor>("X")->dims();
auto paddings = Attr<std::vector<int>>("paddings");
PADDLE_ENFORCE_EQ(x_dim.size() * 2, int(paddings.size()),
PADDLE_ENFORCE_EQ(x_dim.size() * 2, int64_t(paddings.size()),
"Size of paddings should be equal to 2 * dimension size "
"of input tensor.");
std::vector<int> out_dims(x_dim.size());
std::vector<int64_t> out_dims(x_dim.size());
for (int i = 0; i < x_dim.size(); ++i) {
out_dims[i] = x_dim[i] + paddings[i * 2] + paddings[i * 2 + 1];
}
......@@ -95,6 +95,7 @@ class PadOpGrad : public framework::OperatorWithKernel {
"Input(Out@GRAD) should not be null");
auto x_dims = ctx.Input<Tensor>("X")->dims();
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
PADDLE_ENFORCE_NOT_NULL(x_grad, "Output(X@GRAD) should not be null");
x_grad->Resize(x_dims);
}
......
......@@ -28,18 +28,17 @@ using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
template <typename Place, typename T, size_t D>
void PadFunction(const framework::ExecutionContext& context) {
auto pads = context.GetAttr<std::vector<int>>("paddings");
auto pads = context.Attr<std::vector<int>>("paddings");
Eigen::array<std::pair<int, int>, D> paddings;
for (int i = 0; i < paddings.size(); ++i) {
paddings[i].first = pads[i * 2];
paddings[i].second = pads[i * 2 + 1];
}
T pad_value = context.GetAttr<T>("pad_value");
T pad_value = context.Attr<T>("pad_value");
auto* x = context.Input<Tensor>("X");
auto* out = context.Output<Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
auto dims = x->dims();
auto x_tensor = EigenTensor<T, D>::From(*x);
auto out_tensor = EigenTensor<T, D>::From(*out);
......@@ -51,8 +50,8 @@ template <typename Place, typename T>
class PadKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
int dim = context.Input<Tensor>("X")->dims().size();
switch (dim) {
int rank = context.Input<Tensor>("X")->dims().size();
switch (rank) {
case 1:
PadFunction<Place, T, 1>(context);
break;
......@@ -72,14 +71,15 @@ class PadKernel : public framework::OpKernel {
PadFunction<Place, T, 6>(context);
break;
default:
PADDLE_THROW("Only ranks up to 6 supported.");
PADDLE_THROW(
"PadOp only support tensors with no more than 6 dimensions.");
}
}
};
template <typename Place, typename T, size_t D>
void PadGradFunction(const framework::ExecutionContext& context) {
auto pads = context.GetAttr<std::vector<int>>("paddings");
auto pads = context.Attr<std::vector<int>>("paddings");
Eigen::array<std::pair<int, int>, D> paddings;
for (int i = 0; i < paddings.size(); ++i) {
paddings[i].first = -pads[i * 2];
......@@ -99,9 +99,9 @@ template <typename Place, typename T>
class PadGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
size_t dim =
size_t rank =
context.Input<Tensor>(framework::GradVarName("Out"))->dims().size();
switch (dim) {
switch (rank) {
case 1:
PadGradFunction<Place, T, 1>(context);
break;
......@@ -121,7 +121,8 @@ class PadGradKernel : public framework::OpKernel {
PadGradFunction<Place, T, 6>(context);
break;
default:
PADDLE_THROW("Only ranks up to 6 supported.");
PADDLE_THROW(
"PadOp only support tensors with no more than 6 dimensions.");
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册