Commit 8b6fda6f authored by wanghaoshuang

move stride function to ddim.h

Parent: bc632df8
@@ -292,5 +292,13 @@ DDim flatten_to_2d(const DDim& src, int num_col_dims) {
 DDim flatten_to_1d(const DDim& src) { return make_ddim({product(src)}); }
+DDim stride(const DDim& ddim) {
+  std::vector<int64_t> strides(ddim.size());
+  strides[ddim.size() - 1] = 1;
+  for (int i = ddim.size() - 2; i >= 0; --i) {
+    strides[i] = strides[i + 1] * ddim[i + 1];
+  }
+  return framework::make_ddim(strides);
+}
 }  // namespace framework
 }  // namespace paddle
@@ -121,6 +121,7 @@ DDim flatten_to_2d(const DDim& src, int num_col_dims);
 DDim flatten_to_1d(const DDim& src);
+DDim stride(const DDim& ddim);
 }  // namespace framework
 }  // namespace paddle
......
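Note: the relocated framework::stride() produces row-major (C-order) strides, where the last dimension has stride 1 and each earlier entry is the product of all later dimension sizes. A minimal usage sketch follows; the shape values and the ddim.h include path are illustrative assumptions, not part of this commit.

// Sketch only: exercises the stride() helper now declared alongside DDim.
#include "paddle/framework/ddim.h"  // assumed header location for DDim/stride
#include <cassert>

void StrideExample() {
  using namespace paddle::framework;
  DDim dims = make_ddim({2, 3, 4});
  DDim strides = stride(dims);
  // Row-major strides for shape {2, 3, 4}: {3 * 4, 4, 1} = {12, 4, 1}.
  assert(strides[0] == 12);
  assert(strides[1] == 4);
  assert(strides[2] == 1);
}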
@@ -32,8 +32,9 @@ class CropOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
                             "Output(Out) of CropOp should not be null.");
     auto x_dim = ctx.Input<LoDTensor>("X")->dims();
-    auto Y = ctx.Input<LoDTensor>("Y");
-    if (Y == nullptr) {
+    auto *y = ctx.Input<LoDTensor>("Y");
+    auto *out = ctx.Output<LoDTensor>("Out");
+    if (y == nullptr) {
       auto shape = Attr<std::vector<int>>("shape");
       PADDLE_ENFORCE_EQ(
           int64_t(shape.size()), x_dim.size(),
@@ -42,12 +43,12 @@ class CropOp : public framework::OperatorWithKernel {
       for (size_t i = 0; i < shape.size(); ++i) {
         tensor_shape[i] = static_cast<int64_t>(shape[i]);
       }
-      ctx.Output<LoDTensor>("Out")->Resize(framework::make_ddim(tensor_shape));
+      out->Resize(framework::make_ddim(tensor_shape));
     } else {
-      PADDLE_ENFORCE_EQ(framework::arity(x_dim), framework::arity(Y->dims()),
+      PADDLE_ENFORCE_EQ(framework::arity(x_dim), framework::arity(y->dims()),
                         "Tensor rank of both CropOp's "
                         "inputs must be same.");
-      ctx.Output<LoDTensor>("Out")->Resize(Y->dims());
+      out->Resize(y->dims());
     }
   }
 };
......
@@ -24,19 +24,7 @@ namespace operators {  // Internal
 template <typename T, size_t D, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
 using framework::Tensor;
-using framework::DDim;
-// TODO(wanghaoshuang): move this function to other place
-DDim stride(const DDim& ddim) {
-  std::vector<int64_t> strides(ddim.size());
-  strides[ddim.size() - 1] = 1;
-  for (int i = ddim.size() - 2; i >= 0; --i) {
-    strides[i] = strides[i + 1] * ddim[i + 1];
-  }
-  return make_ddim(strides);
-}
 template <typename T>
 class CropKernel : public framework::OpKernel {
@@ -44,13 +32,13 @@ class CropKernel : public framework::OpKernel {
   void Compute(const framework::ExecutionContext& context) const override {
     auto* x = context.Input<Tensor>("X");
     auto* out = context.Output<Tensor>("Out");
-    T* x_data = x->data<T>();
+    const T* x_data = x->data<T>();
     T* out_data = out->mutable_data<T>(context.GetPlace());
-    auto x_stride = stride(x->dims());
-    auto out_stride = stride(out->dims());
+    auto x_stride = framework::stride(x->dims());
+    auto out_stride = framework::stride(out->dims());
     auto offsets = context.Attr<std::vector<int>>("offsets");
     PADDLE_ENFORCE_EQ(
-        x_dims.size(), offsets.size(),
+        x->dims().size(), offsets.size(),
         "Offsets size should be equal to dimension size of input tensor.");
     int64_t offset = 0;
     for (int i = 0; i < offsets.size(); ++i) {
@@ -71,7 +59,7 @@ void CropGradFunction(const framework::ExecutionContext& context) {
     Eigen::array<std::pair<int, int>, D> paddings;
     for (int i = 0; i < D; ++i) {
       paddings[i].first = offsets[i];
-      paddings[i].second = d_x_dims[i] - d_out_dims[i] - offsets[i];
+      paddings[i].second = d_x->dims()[i] - d_out->dims()[i] - offsets[i];
     }
     auto d_x_tensor = EigenTensor<T, D>::From(*d_x);
     auto d_out_tensor = EigenTensor<T, D>::From(*d_out);
......
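In CropKernel::Compute above, the loop that follows the "offsets" check is collapsed in this view; it appears to fold the per-dimension offsets into a single linear start offset using the input's strides. The following is a hedged, standalone sketch of that computation with hypothetical names, not the kernel's actual body.

#include <cstdint>
#include <vector>

// Hypothetical sketch: combine per-dimension crop offsets with row-major
// strides (as returned by framework::stride) into one linear element offset.
int64_t LinearOffset(const std::vector<int64_t>& strides,
                     const std::vector<int>& offsets) {
  int64_t offset = 0;
  for (size_t i = 0; i < offsets.size(); ++i) {
    offset += static_cast<int64_t>(offsets[i]) * strides[i];
  }
  return offset;
}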