未验证 提交 2140e825 编写于 作者: Z zhiboniu 提交者: GitHub

phi_fill_diagonal_tensor (#44649)

* phi_fill_diagonal_tensor

* delete extra lines

* update

* add legacy api test

* rename sig
上级 566c80ff
......@@ -50,12 +50,11 @@ elseif(WITH_MLU)
elseif(WITH_ASCEND_CL)
detection_library(iou_similarity_op SRCS iou_similarity_op.cc
iou_similarity_op_npu.cc)
detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu
prior_box_op_npu.cc)
detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op_npu.cc)
else()
detection_library(iou_similarity_op SRCS iou_similarity_op.cc
iou_similarity_op.cu)
detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu)
detection_library(prior_box_op SRCS prior_box_op.cc)
# detection_library(generate_proposals_v2_op SRCS generate_proposals_v2_op.cc)
endif()
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/prior_box_op.h"
namespace paddle {
namespace operators {
template <typename T>
__device__ inline T clip(T in) {
return min(max(in, 0.), 1.);
}
template <typename T>
__global__ void GenPriorBox(T* out,
const T* aspect_ratios,
const int height,
const int width,
const int im_height,
const int im_width,
const int as_num,
const T offset,
const T step_width,
const T step_height,
const T* min_sizes,
const T* max_sizes,
const int min_num,
bool is_clip,
bool min_max_aspect_ratios_order) {
int num_priors = max_sizes ? as_num * min_num + min_num : as_num * min_num;
int box_num = height * width * num_priors;
CUDA_KERNEL_LOOP(i, box_num) {
int h = i / (num_priors * width);
int w = (i / num_priors) % width;
int p = i % num_priors;
int m = max_sizes ? p / (as_num + 1) : p / as_num;
T cx = (w + offset) * step_width;
T cy = (h + offset) * step_height;
T bw, bh;
T min_size = min_sizes[m];
if (max_sizes) {
int s = p % (as_num + 1);
if (!min_max_aspect_ratios_order) {
if (s < as_num) {
T ar = aspect_ratios[s];
bw = min_size * sqrt(ar) / 2.;
bh = min_size / sqrt(ar) / 2.;
} else {
T max_size = max_sizes[m];
bw = sqrt(min_size * max_size) / 2.;
bh = bw;
}
} else {
if (s == 0) {
bw = bh = min_size / 2.;
} else if (s == 1) {
T max_size = max_sizes[m];
bw = sqrt(min_size * max_size) / 2.;
bh = bw;
} else {
T ar = aspect_ratios[s - 1];
bw = min_size * sqrt(ar) / 2.;
bh = min_size / sqrt(ar) / 2.;
}
}
} else {
int s = p % as_num;
T ar = aspect_ratios[s];
bw = min_size * sqrt(ar) / 2.;
bh = min_size / sqrt(ar) / 2.;
}
T xmin = (cx - bw) / im_width;
T ymin = (cy - bh) / im_height;
T xmax = (cx + bw) / im_width;
T ymax = (cy + bh) / im_height;
out[i * 4] = is_clip ? clip<T>(xmin) : xmin;
out[i * 4 + 1] = is_clip ? clip<T>(ymin) : ymin;
out[i * 4 + 2] = is_clip ? clip<T>(xmax) : xmax;
out[i * 4 + 3] = is_clip ? clip<T>(ymax) : ymax;
}
}
template <typename T>
__global__ void SetVariance(T* out,
const T* var,
const int vnum,
const int num) {
CUDA_KERNEL_LOOP(i, num) { out[i] = var[i % vnum]; }
}
template <typename T>
class PriorBoxOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<paddle::framework::Tensor>("Input");
auto* image = ctx.Input<paddle::framework::Tensor>("Image");
auto* boxes = ctx.Output<paddle::framework::Tensor>("Boxes");
auto* vars = ctx.Output<paddle::framework::Tensor>("Variances");
auto min_sizes = ctx.Attr<std::vector<float>>("min_sizes");
auto max_sizes = ctx.Attr<std::vector<float>>("max_sizes");
auto input_aspect_ratio = ctx.Attr<std::vector<float>>("aspect_ratios");
auto variances = ctx.Attr<std::vector<float>>("variances");
auto flip = ctx.Attr<bool>("flip");
auto clip = ctx.Attr<bool>("clip");
auto min_max_aspect_ratios_order =
ctx.Attr<bool>("min_max_aspect_ratios_order");
std::vector<float> aspect_ratios;
ExpandAspectRatios(input_aspect_ratio, flip, &aspect_ratios);
T step_w = static_cast<T>(ctx.Attr<float>("step_w"));
T step_h = static_cast<T>(ctx.Attr<float>("step_h"));
T offset = static_cast<T>(ctx.Attr<float>("offset"));
auto im_width = image->dims()[3];
auto im_height = image->dims()[2];
auto width = input->dims()[3];
auto height = input->dims()[2];
T step_width, step_height;
if (step_w == 0 || step_h == 0) {
step_width = static_cast<T>(im_width) / width;
step_height = static_cast<T>(im_height) / height;
} else {
step_width = step_w;
step_height = step_h;
}
int num_priors = aspect_ratios.size() * min_sizes.size();
if (max_sizes.size() > 0) {
num_priors += max_sizes.size();
}
int min_num = static_cast<int>(min_sizes.size());
int box_num = width * height * num_priors;
int block = 512;
int grid = (box_num + block - 1) / block;
auto stream = ctx.template device_context<phi::GPUContext>().stream();
boxes->mutable_data<T>(ctx.GetPlace());
vars->mutable_data<T>(ctx.GetPlace());
framework::Tensor r;
framework::TensorFromVector(aspect_ratios, ctx.device_context(), &r);
framework::Tensor min;
framework::TensorFromVector(min_sizes, ctx.device_context(), &min);
T* max_data = nullptr;
framework::Tensor max;
if (max_sizes.size() > 0) {
framework::TensorFromVector(max_sizes, ctx.device_context(), &max);
max_data = max.data<T>();
}
GenPriorBox<T><<<grid, block, 0, stream>>>(boxes->data<T>(),
r.data<T>(),
height,
width,
im_height,
im_width,
aspect_ratios.size(),
offset,
step_width,
step_height,
min.data<T>(),
max_data,
min_num,
clip,
min_max_aspect_ratios_order);
framework::Tensor v;
framework::TensorFromVector(variances, ctx.device_context(), &v);
grid = (box_num * 4 + block - 1) / block;
SetVariance<T><<<grid, block, 0, stream>>>(
vars->data<T>(), v.data<T>(), variances.size(), box_num * 4);
}
}; // namespace operators
} // namespace operators
} // namespace paddle
......@@ -12,64 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fill_diagonal_tensor_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
// calculate the offset\new_dims\(strides of dim1/dim2)\matoffset
void CalMatDims(framework::DDim out_dims,
int dim1,
int dim2,
int64_t *offset,
int64_t *new_dims,
int64_t *strides,
int64_t *matoffset) {
int64_t dimprod = 1, batchdim = 1;
int rank = out_dims.size();
int matoffidx = 0;
for (int i = rank - 1; i >= 0; i--) {
if (i == dim2) {
strides[0] = dimprod;
} else if (i == dim1) {
strides[1] = dimprod;
} else {
batchdim *= out_dims[i];
// matoffset calculate the offset position of the diagonal defined by dim1
// and dim2
// the first circle calculate the final free dimension
// and then calculate the front free dim one by one
if (matoffidx == 0) {
for (int64_t j = 0; j < out_dims[i]; j++) {
matoffset[matoffidx] = dimprod * j;
matoffidx++;
}
} else {
auto size = matoffidx;
for (int64_t j = 1; j < out_dims[i]; j++) {
for (int64_t k = 0; k < size; k++) {
matoffset[matoffidx] = matoffset[k] + dimprod * j;
matoffidx++;
}
}
}
}
dimprod *= out_dims[i];
}
auto diagdim = dim1;
if (*offset >= 0) {
diagdim = std::min(out_dims[dim1], out_dims[dim2] - *offset);
*offset *= strides[0];
} else {
diagdim = std::min(out_dims[dim1] + *offset, out_dims[dim2]);
*offset *= -strides[1];
}
new_dims[0] = batchdim;
new_dims[1] = diagdim;
return;
}
class FillDiagonalTensorOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
......@@ -97,14 +47,6 @@ class FillDiagonalTensorOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *context) const override {
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "FillDiagonalTensor");
OP_INOUT_CHECK(
context->HasOutput("Out"), "Output", "Out", "FillDiagonalTensor");
auto x_dims = context->GetInputDim("X");
context->SetOutputDim("Out", x_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
......@@ -124,77 +66,10 @@ class FillDiagonalTensorOpVarTypeInference
}
};
template <typename T>
class FillDiagonalTensorKernel : public framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext &ctx) const override {
auto *out = ctx.Output<framework::Tensor>("Out");
auto *srctensor = ctx.Input<framework::Tensor>("Y");
auto dim1 = ctx.Attr<int>("dim1");
auto dim2 = ctx.Attr<int>("dim2");
auto offset = ctx.Attr<int64_t>("offset");
auto *xin = ctx.Input<framework::Tensor>("X");
T *out_data = out->mutable_data<T>(ctx.GetPlace());
const T *fill_data = srctensor->data<T>();
framework::TensorCopy(*xin, ctx.GetPlace(), out);
auto out_dims = out->dims();
auto matdims = srctensor->dims();
auto fill_dims = phi::flatten_to_2d(matdims, matdims.size() - 1);
int64_t new_dims[2], strides[2];
std::vector<int64_t> matdim;
matdim.resize(fill_dims[0]);
CalMatDims(out_dims, dim1, dim2, &offset, new_dims, strides, matdim.data());
PADDLE_ENFORCE_EQ(
new_dims[0],
fill_dims[0],
platform::errors::InvalidArgument("The dims should be %d x %d, but get "
"%d x %d in fill tensor Y",
new_dims[0],
new_dims[1],
fill_dims[0],
fill_dims[1]));
PADDLE_ENFORCE_EQ(
new_dims[1],
fill_dims[1],
platform::errors::InvalidArgument("The dims should be %d x %d, but get "
"%d x %d in fill tensor Y",
new_dims[0],
new_dims[1],
fill_dims[0],
fill_dims[1]));
auto size = out->numel();
for (int64_t i = 0; i < fill_dims[0]; i += 1) {
auto sumoff = matdim[i] + offset;
for (int64_t j = 0; j < fill_dims[1]; j += 1) {
auto fill_index = j * (strides[1] + strides[0]) + sumoff;
if (fill_index < size) {
out_data[fill_index] = fill_data[i * fill_dims[1] + j];
}
}
}
}
};
class FillDiagonalTensorGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
"Input",
"Out@GRAD",
"mul");
auto x_dims = ctx->GetInputDim(framework::GradVarName("Out"));
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
// Note: don't get data type from ctx.Input<framework::Tensor>("Input");
......@@ -219,50 +94,6 @@ class FillDiagonalTensorGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
template <typename T>
class FillDiagonalTensorGradKernel : public framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext &ctx) const override {
auto *dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto *dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto dim1 = ctx.Attr<int>("dim1");
auto dim2 = ctx.Attr<int>("dim2");
auto offset = ctx.Attr<int64_t>("offset");
auto matrows = 1;
if (dx) {
auto *data = dx->mutable_data<T>(ctx.GetPlace());
auto dx_dims = dx->dims();
for (int i = 0; i < dx_dims.size(); i++) {
if (i != dim1 && i != dim2) {
matrows *= dx_dims[i];
}
}
int64_t new_dims[2], strides[2];
std::vector<int64_t> matdim;
matdim.resize(matrows);
CalMatDims(
dx_dims, dim1, dim2, &offset, new_dims, strides, matdim.data());
auto size = dx->numel();
framework::TensorCopy(*dout, ctx.GetPlace(), dx);
for (int64_t i = 0; i < new_dims[0]; i += 1) {
auto sumoff = matdim[i] + offset;
for (int64_t j = 0; j < new_dims[1]; j += 1) {
auto fill_index = j * (strides[1] + strides[0]) + sumoff;
if (fill_index < size) {
data[fill_index] = 0;
}
}
}
}
}
};
DECLARE_INPLACE_OP_INFERER(FillDiagonalTensorOpInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(FillDiagonalTensorGradOpInplaceInferer,
{framework::GradVarName("Out"),
......@@ -272,41 +103,25 @@ DECLARE_INPLACE_OP_INFERER(FillDiagonalTensorGradOpInplaceInferer,
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(fill_diagonal_tensor,
FillDiagonalTensorInferShapeFunctor,
PD_INFER_META(phi::FillDiagonalTensorInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(
fill_diagonal_tensor_grad,
FillDiagonalTensorGradInferShapeFunctor,
PD_INFER_META(phi::FillDiagonalTensorGradInferMeta));
REGISTER_OPERATOR(
fill_diagonal_tensor,
ops::FillDiagonalTensorOp,
ops::FillDiagonalTensorOpMaker,
ops::FillDiagonalTensorOpVarTypeInference,
ops::FillDiagonalTensorGradOpMaker<paddle::framework::OpDesc>,
ops::FillDiagonalTensorGradOpMaker<paddle::imperative::OpBase>,
ops::FillDiagonalTensorOpInplaceInferer);
ops::FillDiagonalTensorOpMaker,
ops::FillDiagonalTensorOpInplaceInferer,
ops::FillDiagonalTensorOpVarTypeInference,
FillDiagonalTensorInferShapeFunctor);
REGISTER_OPERATOR(fill_diagonal_tensor_grad,
ops::FillDiagonalTensorGradOp,
ops::FillDiagonalTensorGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(
fill_diagonal_tensor,
ops::FillDiagonalTensorKernel<float>,
ops::FillDiagonalTensorKernel<double>,
ops::FillDiagonalTensorKernel<int64_t>,
ops::FillDiagonalTensorKernel<int>,
ops::FillDiagonalTensorKernel<int8_t>,
ops::FillDiagonalTensorKernel<uint8_t>,
ops::FillDiagonalTensorKernel<paddle::platform::float16>,
ops::FillDiagonalTensorKernel<paddle::platform::complex<float>>,
ops::FillDiagonalTensorKernel<paddle::platform::complex<double>>,
ops::FillDiagonalTensorKernel<bool>);
REGISTER_OP_CPU_KERNEL(
fill_diagonal_tensor_grad,
ops::FillDiagonalTensorGradKernel<float>,
ops::FillDiagonalTensorGradKernel<double>,
ops::FillDiagonalTensorGradKernel<int64_t>,
ops::FillDiagonalTensorGradKernel<int>,
ops::FillDiagonalTensorGradKernel<int8_t>,
ops::FillDiagonalTensorGradKernel<uint8_t>,
ops::FillDiagonalTensorGradKernel<paddle::platform::float16>,
ops::FillDiagonalTensorGradKernel<paddle::platform::complex<float>>,
ops::FillDiagonalTensorGradKernel<paddle::platform::complex<double>>,
ops::FillDiagonalTensorGradKernel<bool>);
ops::FillDiagonalTensorGradOpInplaceInferer,
FillDiagonalTensorGradInferShapeFunctor);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fill_diagonal_tensor_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
__global__ void fill_diagonal_tensor_kernel(int64_t size,
T *out_data,
const T *fill_data,
int64_t *strides,
int64_t *matdim,
int64_t offset,
int64_t fill_dims0,
int64_t fill_dims1) {
int64_t i = blockIdx.x;
auto sumoff = matdim[i] + offset;
for (int64_t j = threadIdx.x; j < fill_dims1; j += blockDim.x) {
auto fill_index = j * (strides[1] + strides[0]) + sumoff;
if (fill_index < size) {
out_data[fill_index] = fill_data[i * fill_dims1 + j];
}
}
}
template <typename T>
__global__ void fill_grad_kernel(int64_t size,
T *out_data,
int64_t *strides,
int64_t *matdim,
int64_t offset,
int64_t fill_dims0,
int64_t fill_dims1) {
int64_t i = blockIdx.x;
auto sumoff = matdim[i] + offset;
for (int64_t j = threadIdx.x; j < fill_dims1; j += blockDim.x) {
auto fill_index = j * (strides[1] + strides[0]) + sumoff;
if (fill_index < size) {
out_data[fill_index] = T(0);
}
}
}
template <typename T>
class FillDiagonalTensorCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
#ifdef __HIPCC__
const int64_t kMaxBlockDim = 256;
#else
const int64_t kMaxBlockDim = 512;
#endif
auto *out = ctx.Output<framework::Tensor>("Out");
auto *srctensor = ctx.Input<framework::Tensor>("Y");
auto dim1 = ctx.Attr<int>("dim1");
auto dim2 = ctx.Attr<int>("dim2");
auto offset = ctx.Attr<int64_t>("offset");
auto *xin = ctx.Input<framework::Tensor>("X");
framework::TensorCopy(*xin, ctx.GetPlace(), out);
T *out_data = out->mutable_data<T>(ctx.GetPlace());
const T *fill_data = srctensor->data<T>();
auto out_dims = out->dims();
auto matdims = srctensor->dims();
auto fill_dims = phi::flatten_to_2d(matdims, matdims.size() - 1);
int64_t new_dims[2];
std::vector<int64_t> memory_block;
memory_block.resize(2 + fill_dims[0]);
int64_t *strides = &(memory_block[0]);
int64_t *matdim = &(memory_block[2]);
CalMatDims(out_dims, dim1, dim2, &offset, new_dims, strides, matdim);
PADDLE_ENFORCE_EQ(
new_dims[0],
fill_dims[0],
platform::errors::InvalidArgument("The dims should be %d x %d, but get "
"%d x %d in fill tensor Y",
new_dims[0],
new_dims[1],
fill_dims[0],
fill_dims[1]));
PADDLE_ENFORCE_EQ(
new_dims[1],
fill_dims[1],
platform::errors::InvalidArgument("The dims should be %d x %d, but get "
"%d x %d in fill tensor Y",
new_dims[0],
new_dims[1],
fill_dims[0],
fill_dims[1]));
auto size = out->numel();
auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
auto stream = dev_ctx.stream();
Tensor tensor_tmp;
int64_t *memory_block_cu =
tensor_tmp.mutable_data<int64_t>({2 + fill_dims[0]}, ctx.GetPlace());
const auto gpu_place = ctx.GetPlace();
memory::Copy(gpu_place,
memory_block_cu,
platform::CPUPlace(),
memory_block.data(),
sizeof(int64_t) * (2 + fill_dims[0]),
stream);
int64_t *strides_cu = &memory_block_cu[0], *matdim_cu = &memory_block_cu[2];
auto kGridDim = new_dims[0];
auto kBlockDim = std::min(int64_t(new_dims[1]), kMaxBlockDim);
fill_diagonal_tensor_kernel<T>
<<<kGridDim, kBlockDim, 0, stream>>>(size,
out_data,
fill_data,
strides_cu,
matdim_cu,
offset,
fill_dims[0],
fill_dims[1]);
}
};
template <typename T>
class FillDiagonalTensorGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
#ifdef __HIPCC__
const int64_t kMaxBlockDim = 256;
#else
const int64_t kMaxBlockDim = 512;
#endif
auto *dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto *dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto dim1 = ctx.Attr<int>("dim1");
auto dim2 = ctx.Attr<int>("dim2");
auto offset = ctx.Attr<int64_t>("offset");
auto matrows = 1;
if (dx) {
auto *data = dx->mutable_data<T>(ctx.GetPlace());
auto dx_dims = dx->dims();
framework::TensorCopy(*dout, ctx.GetPlace(), dx);
for (int i = 0; i < dx_dims.size(); i++) {
if (i != dim1 && i != dim2) {
matrows *= dx_dims[i];
}
}
int64_t new_dims[2];
std::vector<int64_t> memory_block;
memory_block.resize(2 + matrows);
int64_t *strides = &memory_block[0];
int64_t *matdim = &memory_block[2];
CalMatDims(dx_dims, dim1, dim2, &offset, new_dims, strides, matdim);
auto size = dx->numel();
auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
auto stream = dev_ctx.stream();
Tensor tensor_tmp;
int64_t *memory_block_cu =
tensor_tmp.mutable_data<int64_t>({2 + matrows}, ctx.GetPlace());
const auto gpu_place = ctx.GetPlace();
memory::Copy(gpu_place,
memory_block_cu,
platform::CPUPlace(),
memory_block.data(),
sizeof(int64_t) * (2 + matrows),
stream);
int64_t *strides_cu = &memory_block_cu[0],
*matdim_cu = &memory_block_cu[2];
auto kGridDim = new_dims[0];
auto kBlockDim = std::min(int64_t(new_dims[1]), kMaxBlockDim);
fill_grad_kernel<T><<<kGridDim, kBlockDim, 0, stream>>>(
size, data, strides_cu, matdim_cu, offset, new_dims[0], new_dims[1]);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
fill_diagonal_tensor,
ops::FillDiagonalTensorCUDAKernel<float>,
ops::FillDiagonalTensorCUDAKernel<double>,
ops::FillDiagonalTensorCUDAKernel<plat::float16>,
ops::FillDiagonalTensorCUDAKernel<int>,
ops::FillDiagonalTensorCUDAKernel<int64_t>,
ops::FillDiagonalTensorCUDAKernel<int8_t>,
ops::FillDiagonalTensorCUDAKernel<uint8_t>,
ops::FillDiagonalTensorCUDAKernel<paddle::platform::complex<float>>,
ops::FillDiagonalTensorCUDAKernel<paddle::platform::complex<double>>,
ops::FillDiagonalTensorCUDAKernel<bool>);
REGISTER_OP_CUDA_KERNEL(
fill_diagonal_tensor_grad,
ops::FillDiagonalTensorGradCUDAKernel<float>,
ops::FillDiagonalTensorGradCUDAKernel<double>,
ops::FillDiagonalTensorGradCUDAKernel<int>,
ops::FillDiagonalTensorGradCUDAKernel<int64_t>,
ops::FillDiagonalTensorGradCUDAKernel<plat::float16>,
ops::FillDiagonalTensorGradCUDAKernel<int8_t>,
ops::FillDiagonalTensorGradCUDAKernel<uint8_t>,
ops::FillDiagonalTensorGradCUDAKernel<paddle::platform::complex<float>>,
ops::FillDiagonalTensorGradCUDAKernel<paddle::platform::complex<double>>,
ops::FillDiagonalTensorGradCUDAKernel<bool>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
void CalMatDims(framework::DDim out_dims,
int dim1,
int dim2,
int64_t *offset,
int64_t *new_dims,
int64_t *strides,
int64_t *matoffset);
} // namespace operators
} // namespace paddle
......@@ -874,6 +874,16 @@
inplace : (x -> out)
backward : fill_diagonal_grad
- api : fill_diagonal_tensor
args : (Tensor x, Tensor y, int64_t offset, int dim1, int dim2)
output : Tensor(out)
infer_meta :
func : FillDiagonalTensorInferMeta
kernel :
func : fill_diagonal_tensor
inplace : (x -> out)
backward : fill_diagonal_tensor_grad
- api : flatten
args : (Tensor x, int start_axis, int stop_axis)
output : Tensor(out), Tensor(xshape)
......
......@@ -820,6 +820,15 @@
func : FillDiagonalGradInferMeta
kernel :
func : fill_diagonal_grad
- backward_api : fill_diagonal_tensor_grad
forward : fill_diagonal_tensor (Tensor x, Tensor y, int64_t offset, int dim1, int dim2) -> Tensor(out)
args : (Tensor out_grad, int64_t offset, int dim1, int dim2)
output : Tensor(x_grad)
infer_meta :
func : FillDiagonalTensorGradInferMeta
kernel :
func : fill_diagonal_tensor_grad
inplace : (out_grad -> x_grad)
- backward_api : flatten_grad
......
......@@ -297,6 +297,17 @@ void FillDiagonalGradInferMeta(const MetaTensor& dout,
}
}
void FillDiagonalTensorGradInferMeta(const MetaTensor& out_grad,
int64_t offset,
int dim1,
int dim2,
MetaTensor* x_grad) {
if (x_grad != nullptr) {
x_grad->set_dims(out_grad.dims());
x_grad->set_dtype(out_grad.dtype());
}
}
void GatherNdGradInferMeta(const MetaTensor& x,
const MetaTensor& index,
const MetaTensor& out_grad,
......
......@@ -140,6 +140,12 @@ void EigvalshGradInferMeta(const MetaTensor& out_v,
void FillDiagonalGradInferMeta(
const MetaTensor& dout, float value, int offset, bool wrap, MetaTensor* dx);
void FillDiagonalTensorGradInferMeta(const MetaTensor& out_grad,
int64_t offset,
int dim1,
int dim2,
MetaTensor* x_grad);
void GatherNdGradInferMeta(const MetaTensor& x,
const MetaTensor& index,
const MetaTensor& out_grad,
......
......@@ -1174,6 +1174,19 @@ void ExpandAsInferMeta(const MetaTensor& x,
#undef MAX_RANK_SUPPORTED
}
void FillDiagonalTensorInferMeta(const MetaTensor& x,
const MetaTensor& y,
int64_t offset,
int dim1,
int dim2,
MetaTensor* out) {
PADDLE_ENFORCE_NOT_NULL(out,
phi::errors::InvalidArgument(
"Output Tensor (out) should not be nullptr."));
out->set_dims(x.dims());
out->set_dtype(x.dtype());
}
void GatherInferMeta(const MetaTensor& x,
const MetaTensor& index,
const Scalar& axis,
......
......@@ -182,6 +182,13 @@ void ExpandAsInferMeta(const MetaTensor& x,
const std::vector<int>& target_shape,
MetaTensor* out);
void FillDiagonalTensorInferMeta(const MetaTensor& x,
const MetaTensor& y,
int64_t offset,
int dim1,
int dim2,
MetaTensor* out);
void GatherInferMeta(const MetaTensor& x,
const MetaTensor& index,
const Scalar& axis,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fill_diagonal_tensor_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void FillDiagonalTensorGradKernel(const Context& ctx,
const DenseTensor& out_grad,
int64_t offset,
int dim1,
int dim2,
DenseTensor* x_grad) {
auto matrows = 1;
if (x_grad) {
auto* data = ctx.template Alloc<T>(x_grad);
auto dx_dims = x_grad->dims();
for (int i = 0; i < dx_dims.size(); i++) {
if (i != dim1 && i != dim2) {
matrows *= dx_dims[i];
}
}
int64_t new_dims[2], strides[2];
std::vector<int64_t> matdim;
matdim.resize(matrows);
CalMatDims(dx_dims, dim1, dim2, &offset, new_dims, strides, matdim.data());
auto size = x_grad->numel();
phi::Copy(ctx, out_grad, ctx.GetPlace(), false, x_grad);
for (int64_t i = 0; i < new_dims[0]; i += 1) {
auto sumoff = matdim[i] + offset;
for (int64_t j = 0; j < new_dims[1]; j += 1) {
auto fill_index = j * (strides[1] + strides[0]) + sumoff;
if (fill_index < size) {
data[fill_index] = 0;
}
}
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(fill_diagonal_tensor_grad,
CPU,
ALL_LAYOUT,
phi::FillDiagonalTensorGradKernel,
float,
double,
int64_t,
int,
int8_t,
uint8_t,
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fill_diagonal_tensor_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
namespace phi {
void CalMatDims(phi::DDim out_dims,
int dim1,
int dim2,
int64_t *offset,
int64_t *new_dims,
int64_t *strides,
int64_t *matoffset) {
int64_t dimprod = 1, batchdim = 1;
int rank = out_dims.size();
int matoffidx = 0;
for (int i = rank - 1; i >= 0; i--) {
if (i == dim2) {
strides[0] = dimprod;
} else if (i == dim1) {
strides[1] = dimprod;
} else {
batchdim *= out_dims[i];
// matoffset calculate the offset position of the diagonal defined by dim1
// and dim2
// the first circle calculate the final free dimension
// and then calculate the front free dim one by one
if (matoffidx == 0) {
for (int64_t j = 0; j < out_dims[i]; j++) {
matoffset[matoffidx] = dimprod * j;
matoffidx++;
}
} else {
auto size = matoffidx;
for (int64_t j = 1; j < out_dims[i]; j++) {
for (int64_t k = 0; k < size; k++) {
matoffset[matoffidx] = matoffset[k] + dimprod * j;
matoffidx++;
}
}
}
}
dimprod *= out_dims[i];
}
auto diagdim = dim1;
if (*offset >= 0) {
diagdim = std::min(out_dims[dim1], out_dims[dim2] - *offset);
*offset *= strides[0];
} else {
diagdim = std::min(out_dims[dim1] + *offset, out_dims[dim2]);
*offset *= -strides[1];
}
new_dims[0] = batchdim;
new_dims[1] = diagdim;
return;
}
template <typename T, typename Context>
void FillDiagonalTensorKernel(const Context &ctx,
const DenseTensor &x,
const DenseTensor &y,
int64_t offset,
int dim1,
int dim2,
DenseTensor *out) {
T *out_data = ctx.template Alloc<T>(out);
const T *fill_data = y.data<T>();
phi::Copy(ctx, x, ctx.GetPlace(), false, out);
auto out_dims = out->dims();
auto matdims = y.dims();
auto fill_dims = phi::flatten_to_2d(matdims, matdims.size() - 1);
int64_t new_dims[2], strides[2];
std::vector<int64_t> matdim;
matdim.resize(fill_dims[0]);
CalMatDims(out_dims, dim1, dim2, &offset, new_dims, strides, matdim.data());
PADDLE_ENFORCE_EQ(
new_dims[0],
fill_dims[0],
errors::InvalidArgument("The dims should be %d x %d, but get "
"%d x %d in fill tensor Y",
new_dims[0],
new_dims[1],
fill_dims[0],
fill_dims[1]));
PADDLE_ENFORCE_EQ(
new_dims[1],
fill_dims[1],
errors::InvalidArgument("The dims should be %d x %d, but get "
"%d x %d in fill tensor Y",
new_dims[0],
new_dims[1],
fill_dims[0],
fill_dims[1]));
auto size = out->numel();
for (int64_t i = 0; i < fill_dims[0]; i += 1) {
auto sumoff = matdim[i] + offset;
for (int64_t j = 0; j < fill_dims[1]; j += 1) {
auto fill_index = j * (strides[1] + strides[0]) + sumoff;
if (fill_index < size) {
out_data[fill_index] = fill_data[i * fill_dims[1] + j];
}
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(fill_diagonal_tensor,
CPU,
ALL_LAYOUT,
phi::FillDiagonalTensorKernel,
float,
double,
int64_t,
int,
int8_t,
uint8_t,
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
bool) {}
......@@ -252,7 +252,6 @@ void SliceOneClass(const Context& ctx,
const DenseTensor& items,
const int class_id,
DenseTensor* one_class_item) {
// T* item_data = one_class_item->mutable_data<T>(ctx.GetPlace());
T* item_data = ctx.template Alloc<T>(one_class_item);
const T* items_data = items.data<T>();
const int64_t num_item = items.dims()[0];
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void FillDiagonalTensorGradKernel(const Context &ctx,
const DenseTensor &out_grad,
int64_t offset,
int dim1,
int dim2,
DenseTensor *x_grad);
void CalMatDims(phi::DDim out_dims,
int dim1,
int dim2,
int64_t *offset,
int64_t *new_dims,
int64_t *strides,
int64_t *matoffset);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void FillDiagonalTensorKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
int64_t offset,
int dim1,
int dim2,
DenseTensor* out);
void CalMatDims(phi::DDim out_dims,
int dim1,
int dim2,
int64_t* offset,
int64_t* new_dims,
int64_t* strides,
int64_t* matoffset);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fill_diagonal_tensor_grad_kernel.h"
#include <algorithm>
#include <vector>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T>
__global__ void fill_grad_kernel(int64_t size,
T *out_data,
int64_t *strides,
int64_t *matdim,
int64_t offset,
int64_t fill_dims0,
int64_t fill_dims1) {
int64_t i = blockIdx.x;
auto sumoff = matdim[i] + offset;
for (int64_t j = threadIdx.x; j < fill_dims1; j += blockDim.x) {
auto fill_index = j * (strides[1] + strides[0]) + sumoff;
if (fill_index < size) {
out_data[fill_index] = T(0);
}
}
}
template <typename T, typename Context>
void FillDiagonalTensorGradKernel(const Context &ctx,
const DenseTensor &out_grad,
int64_t offset,
int dim1,
int dim2,
DenseTensor *x_grad) {
#ifdef __HIPCC__
const int64_t kMaxBlockDim = 256;
#else
const int64_t kMaxBlockDim = 512;
#endif
auto matrows = 1;
if (x_grad) {
auto *data = ctx.template Alloc<T>(x_grad);
auto dx_dims = x_grad->dims();
phi::Copy(ctx, out_grad, ctx.GetPlace(), false, x_grad);
for (int i = 0; i < dx_dims.size(); i++) {
if (i != dim1 && i != dim2) {
matrows *= dx_dims[i];
}
}
int64_t new_dims[2];
std::vector<int64_t> memory_block;
memory_block.resize(2 + matrows);
int64_t *strides = &memory_block[0];
int64_t *matdim = &memory_block[2];
CalMatDims(dx_dims, dim1, dim2, &offset, new_dims, strides, matdim);
auto size = x_grad->numel();
auto stream = ctx.stream();
DenseTensor tensor_tmp;
tensor_tmp.Resize(phi::make_ddim({2 + matrows}));
int64_t *memory_block_cu = ctx.template Alloc<int64_t>(&tensor_tmp);
const auto gpu_place = ctx.GetPlace();
paddle::memory::Copy(gpu_place,
memory_block_cu,
CPUPlace(),
memory_block.data(),
sizeof(int64_t) * (2 + matrows),
stream);
int64_t *strides_cu = &memory_block_cu[0], *matdim_cu = &memory_block_cu[2];
auto kGridDim = new_dims[0];
auto kBlockDim = std::min(int64_t(new_dims[1]), kMaxBlockDim);
fill_grad_kernel<T><<<kGridDim, kBlockDim, 0, stream>>>(
size, data, strides_cu, matdim_cu, offset, new_dims[0], new_dims[1]);
}
}
} // namespace phi
PD_REGISTER_KERNEL(fill_diagonal_tensor_grad,
GPU,
ALL_LAYOUT,
phi::FillDiagonalTensorGradKernel,
float,
double,
int64_t,
int,
int8_t,
uint8_t,
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fill_diagonal_tensor_kernel.h"
#include <algorithm>
#include <vector>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
namespace phi {
template <typename T>
__global__ void fill_diagonal_tensor_kernel(int64_t size,
T *out_data,
const T *fill_data,
int64_t *strides,
int64_t *matdim,
int64_t offset,
int64_t fill_dims0,
int64_t fill_dims1) {
int64_t i = blockIdx.x;
auto sumoff = matdim[i] + offset;
for (int64_t j = threadIdx.x; j < fill_dims1; j += blockDim.x) {
auto fill_index = j * (strides[1] + strides[0]) + sumoff;
if (fill_index < size) {
out_data[fill_index] = fill_data[i * fill_dims1 + j];
}
}
}
template <typename T, typename Context>
void FillDiagonalTensorKernel(const Context &ctx,
const DenseTensor &x,
const DenseTensor &y,
int64_t offset,
int dim1,
int dim2,
DenseTensor *out) {
#ifdef __HIPCC__
const int64_t kMaxBlockDim = 256;
#else
const int64_t kMaxBlockDim = 512;
#endif
phi::Copy(ctx, x, ctx.GetPlace(), false, out);
T *out_data = ctx.template Alloc<T>(out);
const T *fill_data = y.data<T>();
auto out_dims = out->dims();
auto matdims = y.dims();
auto fill_dims = phi::flatten_to_2d(matdims, matdims.size() - 1);
int64_t new_dims[2];
std::vector<int64_t> memory_block;
memory_block.resize(2 + fill_dims[0]);
int64_t *strides = &(memory_block[0]);
int64_t *matdim = &(memory_block[2]);
CalMatDims(out_dims, dim1, dim2, &offset, new_dims, strides, matdim);
PADDLE_ENFORCE_EQ(
new_dims[0],
fill_dims[0],
errors::InvalidArgument("The dims should be %d x %d, but get "
"%d x %d in fill tensor Y",
new_dims[0],
new_dims[1],
fill_dims[0],
fill_dims[1]));
PADDLE_ENFORCE_EQ(
new_dims[1],
fill_dims[1],
errors::InvalidArgument("The dims should be %d x %d, but get "
"%d x %d in fill tensor Y",
new_dims[0],
new_dims[1],
fill_dims[0],
fill_dims[1]));
auto size = out->numel();
auto stream = ctx.stream();
DenseTensor tensor_tmp;
tensor_tmp.Resize(phi::make_ddim({2 + fill_dims[0]}));
int64_t *memory_block_cu = ctx.template Alloc<int64_t>(&tensor_tmp);
const auto gpu_place = ctx.GetPlace();
paddle::memory::Copy(gpu_place,
memory_block_cu,
CPUPlace(),
memory_block.data(),
sizeof(int64_t) * (2 + fill_dims[0]),
stream);
int64_t *strides_cu = &memory_block_cu[0], *matdim_cu = &memory_block_cu[2];
auto kGridDim = new_dims[0];
auto kBlockDim = std::min(int64_t(new_dims[1]), kMaxBlockDim);
fill_diagonal_tensor_kernel<T>
<<<kGridDim, kBlockDim, 0, stream>>>(size,
out_data,
fill_data,
strides_cu,
matdim_cu,
offset,
fill_dims[0],
fill_dims[1]);
}
} // namespace phi
PD_REGISTER_KERNEL(fill_diagonal_tensor,
GPU,
ALL_LAYOUT,
phi::FillDiagonalTensorKernel,
float,
double,
int64_t,
int,
int8_t,
uint8_t,
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature FillDiagonalTensorOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"fill_diagonal_tensor", {"X", "Y"}, {"offset", "dim1", "dim2"}, {"Out"});
}
KernelSignature FillDiagonalTensorGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("fill_diagonal_tensor_grad",
{"Out@GRAD"},
{"offset", "dim1", "dim2"},
{"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(fill_diagonal_tensor,
phi::FillDiagonalTensorOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(fill_diagonal_tensor_grad,
phi::FillDiagonalTensorGradOpArgumentMapping);
......@@ -1457,7 +1457,6 @@ class OpTest(unittest.TestCase):
# see details: https://stackoverflow.com/questions/38331703/why-does-numpys-broadcasting-sometimes-allow-comparing-arrays-of-different-leng
if expect_np.size == 0:
self.op_test.assertTrue(actual_np.size == 0) # }}}
# print("actual_np, expect_np", actual_np, expect_np)
self._compare_numpy(name, actual_np, expect_np)
if isinstance(expect, tuple):
self._compare_list(name, actual, expect)
......
......@@ -86,6 +86,7 @@ class TensorFillDiagTensor_Test(OpTest):
def setUp(self):
self.op_type = "fill_diagonal_tensor"
self.python_api = paddle.tensor.manipulation.fill_diagonal_tensor
self.init_kernel_type()
x = np.random.random((10, 10)).astype(self.dtype)
y = np.random.random((10, )).astype(self.dtype)
......@@ -96,22 +97,23 @@ class TensorFillDiagTensor_Test(OpTest):
self.inputs = {"X": x, "Y": y}
self.outputs = {'Out': out}
self.attrs = {"dim1": dim1, "dim2": dim2, "offset": offset}
self.attrs = {"offset": offset, "dim1": dim1, "dim2": dim2}
def init_kernel_type(self):
self.dtype = np.float64
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_eager=True)
class TensorFillDiagTensor_Test2(TensorFillDiagTensor_Test):
def setUp(self):
self.op_type = "fill_diagonal_tensor"
self.python_api = paddle.tensor.manipulation.fill_diagonal_tensor
self.init_kernel_type()
x = np.random.random((2, 20, 25)).astype(self.dtype)
y = np.random.random((2, 20)).astype(self.dtype)
......@@ -122,7 +124,7 @@ class TensorFillDiagTensor_Test2(TensorFillDiagTensor_Test):
self.inputs = {"X": x, "Y": y}
self.outputs = {'Out': out}
self.attrs = {"dim1": dim1, "dim2": dim2, "offset": offset}
self.attrs = {"offset": offset, "dim1": dim1, "dim2": dim2}
def init_kernel_type(self):
self.dtype = np.float32
......@@ -132,6 +134,7 @@ class TensorFillDiagTensor_Test3(TensorFillDiagTensor_Test):
def setUp(self):
self.op_type = "fill_diagonal_tensor"
self.python_api = paddle.tensor.manipulation.fill_diagonal_tensor
self.init_kernel_type()
x = np.random.random((2, 20, 20, 3)).astype(self.dtype)
y = np.random.random((2, 3, 18)).astype(self.dtype)
......@@ -142,11 +145,12 @@ class TensorFillDiagTensor_Test3(TensorFillDiagTensor_Test):
self.inputs = {"X": x, "Y": y}
self.outputs = {'Out': out}
self.attrs = {"dim1": dim1, "dim2": dim2, "offset": offset}
self.attrs = {"offset": offset, "dim1": dim1, "dim2": dim2}
def init_kernel_type(self):
self.dtype = np.float16
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
......@@ -18,6 +18,7 @@ import unittest
import numpy as np
import six
import paddle
from paddle.fluid.framework import _enable_legacy_dygraph
class TensorFillDiagTensor_Test(unittest.TestCase):
......@@ -183,5 +184,9 @@ class TensorFillDiagTensor_Test(unittest.TestCase):
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})
class TensorFillDiagTensor_Test_legacy(TensorFillDiagTensor_Test):
_enable_legacy_dygraph()
if __name__ == '__main__':
unittest.main()
......@@ -888,10 +888,17 @@ def _fill_diagonal_tensor_impl(x, y, offset=0, dim1=0, dim2=1, inplace=False):
y = y.reshape([1, -1])
if inplace:
return _C_ops.fill_diagonal_tensor_(x, y, 'dim1', dim1, 'dim2', dim2,
'offset', offset)
return _C_ops.fill_diagonal_tensor(x, y, 'dim1', dim1, 'dim2', dim2,
'offset', offset)
if in_dygraph_mode():
return _C_ops.final_state_fill_diagonal_tensor_(
x, y, offset, dim1, dim2)
else:
return _C_ops.fill_diagonal_tensor_(x, y, 'offset', offset, 'dim1',
dim1, 'dim2', dim2)
if in_dygraph_mode():
return _C_ops.final_state_fill_diagonal_tensor(x, y, offset, dim1, dim2)
else:
return _C_ops.fill_diagonal_tensor(x, y, 'offset', offset, 'dim1', dim1,
'dim2', dim2)
def fill_diagonal_tensor_(x, y, offset=0, dim1=0, dim2=1, name=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册