Unverified commit 2e4196f6, authored by 123malin, committed by GitHub

add new api for Paddle2.0: nonzero, index_select, roll, cross (#23176)

Parent: f11af6a9
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/cross_op.h"
#include <memory>
namespace paddle {
namespace operators {
using framework::Tensor;
using framework::DDim;
class CrossOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(X) of CrossOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true,
platform::errors::InvalidArgument(
"Input(Index) of CrossOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of CrossOp should not be null."));
auto x_dim = ctx->GetInputDim("X");
auto y_dim = ctx->GetInputDim("Y");
auto dim = ctx->Attrs().Get<int>("dim");
bool dims_match = CheckDims(x_dim, y_dim);
PADDLE_ENFORCE_EQ(dims_match, true,
platform::errors::InvalidArgument(
"The 'shape' of Input(X) should be equal to "
"the 'shape' of Input(Y). But received "
"Input(X).dimensions = [%s], "
"Input(Y).dimensions = [%s]",
x_dim, y_dim));
if (dim != kDefaultDim) {
PADDLE_ENFORCE_EQ(
dim < x_dim.size() && dim >= (0 - x_dim.size()), true,
platform::errors::OutOfRange(
"Attr(dim) is out of range, It's expected "
"to be in range of [-%d, %d]. But received Attr(dim) = %d.",
x_dim.size(), x_dim.size() - 1, dim));
if (dim < 0) {
dim += x_dim.size();
}
PADDLE_ENFORCE_EQ(x_dim[dim] == 3 && y_dim[dim] == 3, true,
platform::errors::InvalidArgument(
"Input(X/Y).dims()[dim] should be equal to 3. "
"But received Input(X/Y).dims()[dim] = %d.",
x_dim[dim]));
}
ctx->SetOutputDim("Out", x_dim);
auto type = ctx->GetInputsVarType("X")[0];
if (type == framework::proto::VarType::LOD_TENSOR) {
ctx->ShareLoD("X", /*->*/ "Out");
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class CrossGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::InvalidArgument("Input(X) should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Y"), true,
platform::errors::InvalidArgument("Input(Y) should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::InvalidArgument(
"Input(Out@GRAD) should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::InvalidArgument(
"Output(X@GRAD) should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("Y")), true,
platform::errors::InvalidArgument(
"Output(Y@GRAD) should not be null."));
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->SetOutputDim(framework::GradVarName("Y"), ctx->GetInputDim("Y"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
class CrossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) the input tensor.");
AddInput("Y", "(Tensor) the second input tensor.");
AddOutput("Out", "(Tensor), the output tensor.");
AddAttr<int>("dim", "the dimension to take the cross-product in.")
.SetDefault(kDefaultDim);
AddComment(R"DOC(
Returns the cross product of vectors in dimension dim of
input and other. Input and other must have the same shape,
and the size of their dim dimension should be 3.
If dim is not given, it defaults to the first dimension
found with size 3.
)DOC");
}
};
template <typename T>
class CrossGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("cross_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Input("Y"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(cross, ops::CrossOp, ops::CrossOpMaker,
ops::CrossGradMaker<paddle::framework::OpDesc>,
ops::CrossGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(cross_grad, ops::CrossGradOp);
REGISTER_OP_CPU_KERNEL(
cross, ops::CrossKernel<paddle::platform::CPUDeviceContext, float>,
ops::CrossKernel<paddle::platform::CPUDeviceContext, double>,
ops::CrossKernel<paddle::platform::CPUDeviceContext, int>,
ops::CrossKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
cross_grad, ops::CrossGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::CrossGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::CrossGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::CrossGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/cross_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
cross, ops::CrossKernel<paddle::platform::CUDADeviceContext, float>,
ops::CrossKernel<paddle::platform::CUDADeviceContext, double>,
ops::CrossKernel<paddle::platform::CUDADeviceContext, int>,
ops::CrossKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
cross_grad,
ops::CrossGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::CrossGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::CrossGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::CrossGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DDim = framework::DDim;
const int kDefaultDim = framework::DDim::kMaxRank;
inline bool CheckDims(const DDim& dims_x, const DDim& dims_y) {
if (dims_x.size() != dims_y.size()) {
return false;
}
for (int i = 0; i < dims_x.size(); i++) {
if (dims_x[i] != dims_y[i]) {
return false;
}
}
return true;
}
template <typename DeviceContext, typename T>
class CrossKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input_x_var = context.InputVar("X");
auto* input_y_var = context.InputVar("Y");
auto* output_var = context.OutputVar("Out");
auto& input_x = input_x_var->Get<LoDTensor>();
auto& input_y = input_y_var->Get<LoDTensor>();
auto* output = output_var->GetMutable<LoDTensor>();
int dim = context.Attr<int>("dim");
auto input_x_dims = input_x.dims();
auto input_y_dims = input_y.dims();
bool dims_match = CheckDims(input_x_dims, input_y_dims);
PADDLE_ENFORCE_EQ(dims_match, true,
platform::errors::InvalidArgument(
"The 'shape' of Input(X) should be equal to "
"the 'shape' of Input(Y). But received "
"Input(X).dimensions = [%s], "
"Input(Y).dimensions = [%s]",
input_x_dims, input_y_dims));
if (dim != kDefaultDim) {
PADDLE_ENFORCE_EQ(
dim < input_x_dims.size() && dim >= (0 - input_x_dims.size()), true,
platform::errors::OutOfRange(
"Attr(dim) is out of range, It's expected "
"to be in range of [-%d, %d]. But received Attr(dim) = %d.",
input_x_dims.size(), input_x_dims.size() - 1, dim));
if (dim < 0) {
dim += input_x_dims.size();
}
PADDLE_ENFORCE_EQ(
input_x_dims[dim] == 3, true,
platform::errors::InvalidArgument(
"Input(X/Y).dims[dim] must be equal to 3. But received: "
"Input(X/Y).dims[dim] = [%d].",
input_x_dims[dim]));
} else {
for (auto i = 0; i < input_x_dims.size(); i++) {
if (input_x_dims[i] == 3) {
dim = i;
break;
}
}
PADDLE_ENFORCE_EQ(dim == kDefaultDim, false,
platform::errors::InvalidArgument(
"There must be at least one dimension 'd' so that "
"Input(X/Y).dims()[d] is equal to 3. "
"But received: Input(X/Y).dims() == [%s].",
input_x_dims));
}
auto outer_loops = 1;
for (auto i = 0; i < dim; i++) {
outer_loops *= input_x_dims[i];
}
auto slice_size = 1;
for (auto i = dim + 1; i < input_x_dims.size(); i++) {
slice_size *= input_x_dims[i];
}
std::vector<T> input_x_vec, input_y_vec;
framework::TensorToVector(input_x, context.device_context(), &input_x_vec);
framework::TensorToVector(input_y, context.device_context(), &input_y_vec);
std::vector<T> out_vec(output->numel());
output->mutable_data<T>(context.GetPlace());
for (auto i = 0; i < outer_loops; i++) {
for (auto j = 0; j < 3; j++) {
auto dst_pos = (3 * i + j) * slice_size;
auto in_pos1 = (3 * i + ((j + 1) % 3)) * slice_size;
auto in_pos2 = (3 * i + ((j + 2) % 3)) * slice_size;
for (auto k = 0; k < slice_size; k++) {
out_vec[dst_pos + k] =
input_x_vec[in_pos1 + k] * input_y_vec[in_pos2 + k] -
input_x_vec[in_pos2 + k] * input_y_vec[in_pos1 + k];
}
}
}
framework::TensorFromVector(out_vec, context.device_context(), output);
output->Resize(input_x_dims);
}
};
template <typename DeviceContext, typename T>
class CrossGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input_x_var = context.InputVar("X");
auto* input_y_var = context.InputVar("Y");
auto* input_out_grad_var = context.InputVar(framework::GradVarName("Out"));
auto* output_x_grad_var = context.OutputVar(framework::GradVarName("X"));
auto* output_y_grad_var = context.OutputVar(framework::GradVarName("Y"));
auto& input_x = input_x_var->Get<LoDTensor>();
auto& input_y = input_y_var->Get<LoDTensor>();
auto& input_out_grad = input_out_grad_var->Get<LoDTensor>();
auto* output_x_grad = output_x_grad_var->GetMutable<LoDTensor>();
auto* output_y_grad = output_y_grad_var->GetMutable<LoDTensor>();
int dim = context.Attr<int>("dim");
auto input_x_dims = input_x.dims();
if (dim != kDefaultDim) {
PADDLE_ENFORCE_EQ(
dim < input_x_dims.size() && dim >= (0 - input_x_dims.size()), true,
platform::errors::OutOfRange(
"Attr(dim) is out of range, It's expected "
"to be in range of [-%d, %d]. But received Attr(dim) = %d.",
input_x_dims.size(), input_x_dims.size() - 1, dim));
if (dim < 0) {
dim += input_x_dims.size();
}
PADDLE_ENFORCE_EQ(
input_x_dims[dim] == 3, true,
platform::errors::InvalidArgument(
"Input(X/Y).dims[dim] must be equal to 3. But received: "
"Input(X/Y).dims[dim] = [%d].",
input_x_dims[dim]));
} else {
for (auto i = 0; i < input_x_dims.size(); i++) {
if (input_x_dims[i] == 3) {
dim = i;
break;
}
}
PADDLE_ENFORCE_EQ(dim == kDefaultDim, false,
platform::errors::InvalidArgument(
"There must be at least one dimension 'd' "
"so that Input(X/Y).dims()[d] is equal to 3. "
"But received: Input(X/Y).dims() == [%s].",
input_x_dims));
}
auto outer_loops = 1;
for (auto i = 0; i < dim; i++) {
outer_loops *= input_x_dims[i];
}
auto slice_size = 1;
for (auto i = dim + 1; i < input_x_dims.size(); i++) {
slice_size *= input_x_dims[i];
}
std::vector<T> input_x_vec, input_y_vec, input_dout_vec;
framework::TensorToVector(input_x, context.device_context(), &input_x_vec);
framework::TensorToVector(input_y, context.device_context(), &input_y_vec);
framework::TensorToVector(input_out_grad, context.device_context(),
&input_dout_vec);
std::vector<T> out_dx_vec(output_x_grad->numel());
std::vector<T> out_dy_vec(output_y_grad->numel());
output_x_grad->mutable_data<T>(context.GetPlace());
output_y_grad->mutable_data<T>(context.GetPlace());
for (auto i = 0; i < outer_loops; i++) {
for (auto j = 0; j < 3; j++) {
auto dst_pos = (3 * i + j) * slice_size;
auto in_pos1 = (3 * i + ((j + 1) % 3)) * slice_size;
auto in_pos2 = (3 * i + ((j + 2) % 3)) * slice_size;
for (auto k = 0; k < slice_size; k++) {
out_dx_vec[dst_pos + k] =
input_dout_vec[in_pos2 + k] * input_y_vec[in_pos1 + k] -
input_dout_vec[in_pos1 + k] * input_y_vec[in_pos2 + k];
out_dy_vec[dst_pos + k] =
input_dout_vec[in_pos1 + k] * input_x_vec[in_pos2 + k] -
input_dout_vec[in_pos2 + k] * input_x_vec[in_pos1 + k];
}
}
}
framework::TensorFromVector(out_dx_vec, context.device_context(),
output_x_grad);
framework::TensorFromVector(out_dy_vec, context.device_context(),
output_y_grad);
output_x_grad->Resize(input_x_dims);
output_y_grad->Resize(input_x_dims);
}
};
} // namespace operators
} // namespace paddle
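The forward kernel above flattens X and Y into outer_loops groups of three slice_size-wide slices and computes z_j = x_{j+1} * y_{j+2} - x_{j+2} * y_{j+1} with component indices taken mod 3; the backward kernel uses the matching identities dL/dX = Y x dL/dZ and dL/dY = dL/dZ x X per 3-vector. A minimal NumPy sketch of the same index arithmetic (np_cross_forward is an illustrative helper for this note, not part of the commit):

import numpy as np

def np_cross_forward(x, y, dim):
    # Same flattening as CrossKernel: dims before `dim` form the outer
    # loop, dims after `dim` form the contiguous slice.
    outer_loops = int(np.prod(x.shape[:dim], dtype=np.int64))
    slice_size = int(np.prod(x.shape[dim + 1:], dtype=np.int64))
    xf = x.reshape(outer_loops, 3, slice_size)
    yf = y.reshape(outer_loops, 3, slice_size)
    out = np.empty_like(xf)
    for j in range(3):
        # z_j = x_{j+1} * y_{j+2} - x_{j+2} * y_{j+1}, indices mod 3.
        out[:, j, :] = (xf[:, (j + 1) % 3, :] * yf[:, (j + 2) % 3, :] -
                        xf[:, (j + 2) % 3, :] * yf[:, (j + 1) % 3, :])
    return out.reshape(x.shape)

x = np.random.rand(4, 3, 5)
y = np.random.rand(4, 3, 5)
assert np.allclose(np_cross_forward(x, y, 1),
                   np.cross(x, y, axisa=1, axisb=1, axisc=1))
# CrossGradKernel computes dX = Y x dZ and dY = dZ x X slice-wise, so
# both can be expressed with the same forward helper:
d_z = np.random.rand(4, 3, 5)
d_x = np_cross_forward(y, d_z, 1)  # Y x dZ
d_y = np_cross_forward(d_z, x, 1)  # dZ x X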
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/index_select_op.h"
#include <memory>
namespace paddle {
namespace operators {
using framework::Tensor;
class IndexSelectOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(X) of IndexSelectOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true,
platform::errors::InvalidArgument(
"Input(Index) of IndexSelectOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of IndexSelectOp should not be null."));
auto input_dim = ctx->GetInputDim("X");
auto index_dim = ctx->GetInputDim("Index");
auto dim = ctx->Attrs().Get<int>("dim");
PADDLE_ENFORCE_EQ(
dim < input_dim.size() && dim >= (0 - input_dim.size()), true,
platform::errors::OutOfRange(
"Attr(dim) is out of range, It's expected "
"to be in range of [-%d, %d]. But received Attr(dim) = %d.",
input_dim.size(), input_dim.size() - 1, dim));
PADDLE_ENFORCE_EQ(
index_dim.size() == 1 || (index_dim.size() == 2 && index_dim[1] == 1),
true, platform::errors::InvalidArgument(
"The 'shape' of Input(Index) must be 1-D tensor. "
"But received: the 'shape' of Input(Index) is [%s], "
"the dimension of Input(Index) is [%d].",
index_dim, index_dim.size()));
auto output_dim = framework::vectorize(input_dim);
if (dim < 0) {
dim += input_dim.size();
}
output_dim[dim] = index_dim[0];
ctx->SetOutputDim("Out", framework::make_ddim(output_dim));
auto type = ctx->GetInputsVarType("X")[0];
if (type == framework::proto::VarType::LOD_TENSOR) {
ctx->ShareLoD("X", /*->*/ "Out");
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class IndexSelectGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("Index"), true,
platform::errors::InvalidArgument("Input(Index) should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::InvalidArgument(
"Input(Out@GRAD) should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::InvalidArgument(
"Output(X@GRAD) should not be null."));
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
class IndexSelectOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) the input tensor.");
AddInput("Index", "the 1-D tensor containing the indices to index.");
AddOutput("Out", "the output tensor.");
AddAttr<int>("dim", "the dimension in which we index.").SetDefault(0);
AddComment(R"DOC(
Returns a new tensor which indexes the input tensor
along dimension dim using the entries in Index, which
is a 1-D Tensor.
The returned tensor has the same number of dimensions
as the original tensor (input). The dim-th dimension
has the same size as the length of Index; other dimensions
have the same size as in the original tensor.
)DOC");
}
};
template <typename T>
class IndexSelectGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("index_select_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Index", this->Input("Index"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(IndexSelectGradNoNeedBufferVarsInference,
"X");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(index_select, ops::IndexSelectOp, ops::IndexSelectOpMaker,
ops::IndexSelectGradMaker<paddle::framework::OpDesc>,
ops::IndexSelectGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(index_select_grad, ops::IndexSelectGradOp,
ops::IndexSelectGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(
index_select,
ops::IndexSelectKernel<paddle::platform::CPUDeviceContext, float>,
ops::IndexSelectKernel<paddle::platform::CPUDeviceContext, double>,
ops::IndexSelectKernel<paddle::platform::CPUDeviceContext, int>,
ops::IndexSelectKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
index_select_grad,
ops::IndexSelectGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::IndexSelectGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::IndexSelectGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::IndexSelectGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/index_select_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
index_select,
ops::IndexSelectKernel<paddle::platform::CUDADeviceContext, float>,
ops::IndexSelectKernel<paddle::platform::CUDADeviceContext, double>,
ops::IndexSelectKernel<paddle::platform::CUDADeviceContext, int>,
ops::IndexSelectKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
index_select_grad,
ops::IndexSelectGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::IndexSelectGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::IndexSelectGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::IndexSelectGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DDim = framework::DDim;
template <typename T, typename IndexT = int>
void IndexSelectInner(const framework::ExecutionContext& context,
const LoDTensor& input, const LoDTensor& index,
LoDTensor* output, int dim) {
auto input_dim = input.dims();
auto input_dim_size = input_dim.size();
auto output_dim = output->dims();
auto slice_size = 1;
for (auto i = dim + 1; i < input_dim_size; i++) {
slice_size *= input_dim[i];
}
auto input_width = slice_size * input_dim[dim];
auto output_width = slice_size * output_dim[dim];
auto outer_nums = 1;
for (auto i = 0; i < dim; i++) {
outer_nums *= input_dim[i];
}
auto index_size = index.dims()[0];
std::vector<T> input_vec;
std::vector<IndexT> index_vec;
TensorToVector(input, context.device_context(), &input_vec);
TensorToVector(index, context.device_context(), &index_vec);
std::vector<T> out_vec(output->numel());
VLOG(3) << "Index_Select_Debug; outer_nums: " << outer_nums
<< "; slice_size: " << slice_size << "; input_width: " << input_width
<< "; output_width: " << output_width
<< "; index_size: " << index_size;
for (auto i = 0; i < outer_nums; i++) {
auto input_start_offset = i * input_width;
auto output_start_offset = i * output_width;
for (auto j = 0; j < index_size; j++) {
IndexT index_value = index_vec[j];
for (auto k = 0; k < slice_size; k++) {
out_vec[output_start_offset + j * slice_size + k] =
input_vec[input_start_offset + index_value * slice_size + k];
}
}
}
output->mutable_data<T>(context.GetPlace());
framework::TensorFromVector(out_vec, context.device_context(), output);
output->Resize(output_dim);
}
template <typename DeviceContext, typename T>
class IndexSelectKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* inputs_var = context.InputVar("X");
auto* index_var = context.InputVar("Index");
auto* output_var = context.OutputVar("Out");
auto& inputs = inputs_var->Get<LoDTensor>();
auto& index = index_var->Get<LoDTensor>();
auto* output = output_var->GetMutable<framework::LoDTensor>();
int dim = context.Attr<int>("dim");
if (dim < 0) {
dim += inputs.dims().size();
}
const auto& index_type = index.type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
IndexSelectInner<T, int>(context, inputs, index, output, dim);
} else if (index_type == framework::proto::VarType::INT64) {
IndexSelectInner<T, int64_t>(context, inputs, index, output, dim);
}
}
};
template <typename T, typename IndexT = int>
void IndexSelectGradInner(const framework::ExecutionContext& context,
const LoDTensor& out_grad, const LoDTensor& index,
LoDTensor* x_grad, int dim) {
std::vector<T> input_vec;
std::vector<IndexT> index_vec;
TensorToVector(out_grad, context.device_context(), &input_vec);
TensorToVector(index, context.device_context(), &index_vec);
auto input_dim = out_grad.dims();
auto input_dim_size = input_dim.size();
auto output_dim = x_grad->dims();
std::vector<T> out_vec(x_grad->numel(), 0);
auto slice_size = 1;
for (auto i = dim + 1; i < input_dim_size; i++) {
slice_size *= input_dim[i];
}
auto input_width = slice_size * input_dim[dim];
auto output_width = slice_size * output_dim[dim];
auto outer_nums = 1;
for (auto i = 0; i < dim; i++) {
outer_nums *= input_dim[i];
}
auto index_size = index.dims()[0];
VLOG(3) << "Index_Select_Grad_Debug; outer_nums: " << outer_nums
<< "; slice_size: " << slice_size << "; input_width: " << input_width
<< "; output_width: " << output_width
<< "; index_size: " << index_size;
for (auto i = 0; i < outer_nums; i++) {
auto input_start_offset = i * input_width;
auto output_start_offset = i * output_width;
for (auto j = 0; j < index_size; j++) {
IndexT index_value = index_vec[j];
for (auto k = 0; k < slice_size; k++) {
out_vec[output_start_offset + index_value * slice_size + k] +=
input_vec[input_start_offset + j * slice_size + k];
}
}
}
x_grad->mutable_data<T>(context.GetPlace());
framework::TensorFromVector(out_vec, context.device_context(), x_grad);
x_grad->Resize(output_dim);
}
template <typename DeviceContext, typename T>
class IndexSelectGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* index_var = context.InputVar("Index");
auto* x_grad_var = context.OutputVar(framework::GradVarName("X"));
auto* out_grad_var = context.InputVar(framework::GradVarName("Out"));
auto& index = index_var->Get<LoDTensor>();
auto& out_grad = out_grad_var->Get<LoDTensor>();
auto* x_grad = x_grad_var->GetMutable<framework::LoDTensor>();
int dim = context.Attr<int>("dim");
if (dim < 0) {
dim += out_grad.dims().size();
}
const auto& index_type = index.type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
IndexSelectGradInner<T, int>(context, out_grad, index, x_grad, dim);
} else if (index_type == framework::proto::VarType::INT64) {
IndexSelectGradInner<T, int64_t>(context, out_grad, index, x_grad, dim);
}
}
};
} // namespace operators
} // namespace paddle
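IndexSelectInner above gathers index_size slices of width slice_size along dim, and IndexSelectGradInner scatter-adds the output gradient back through the same mapping, so repeated indices accumulate gradient. A minimal NumPy sketch of both directions, assuming dim has already been normalized to be non-negative as the kernels do (the np_* helpers are illustrative, not part of the commit):

import numpy as np

def np_index_select(x, index, dim):
    # Forward: gather slices along `dim`; equivalent to IndexSelectInner.
    return np.take(x, index, axis=dim)

def np_index_select_grad(d_out, index, x_shape, dim):
    # Backward: scatter-add each output-gradient slice back to the input
    # position it was gathered from. np.add.at is an unbuffered scatter,
    # so duplicated indices accumulate, matching the `+=` in
    # IndexSelectGradInner.
    d_x = np.zeros(x_shape, dtype=d_out.dtype)
    np.add.at(d_x, (slice(None),) * dim + (index,), d_out)
    return d_x

x = np.arange(12, dtype=np.float64).reshape(3, 4)
idx = np.array([0, 1, 1])
out = np_index_select(x, idx, dim=0)               # rows 0, 1, 1
d_x = np_index_select_grad(np.ones_like(out), idx, x.shape, dim=0)
assert (d_x[1] == 2.0).all()                       # row 1 selected twice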
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/roll_op.h"
#include <memory>
#include <vector>
namespace paddle {
namespace operators {
using framework::Tensor;
class RollOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(X) of RollOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of RollOp should not be null."));
auto dims = ctx->Attrs().Get<std::vector<int64_t>>("dims");
auto shifts = ctx->Attrs().Get<std::vector<int64_t>>("shifts");
PADDLE_ENFORCE_EQ(dims.size(), shifts.size(),
platform::errors::InvalidArgument(
"Attr(dims).size() should be equl to "
"Attr(shifts).size(). But received "
"Attr(dims).size() = %d, Attr(shifts).size() = %d",
dims.size(), shifts.size()));
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
auto type = ctx->GetInputsVarType("X")[0];
if (type == framework::proto::VarType::LOD_TENSOR) {
ctx->ShareLoD("X", /*->*/ "Out");
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class RollGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::InvalidArgument(
"Input(Out@GRAD) should be not null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::InvalidArgument(
"Output(X@GRAD) should be not null."));
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
class RollOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) the input tensor.");
AddOutput("Out", "(Tensor), the output tensor.");
AddAttr<std::vector<int64_t>>("shifts",
"The number of places by which the elements "
"of the tensor are shifted.")
.SetDefault({});
AddAttr<std::vector<int64_t>>(
"dims",
"Axis along which to roll. It must have the same size "
"with shifts.")
.SetDefault({});
AddComment(R"DOC(
Roll the tensor along the given dimension(s).
Elements that are shifted beyond the last position
are re-introduced at the first position. If a dimension
is not specified, the tensor will be flattened before
rolling and then restored to the original shape.
)DOC");
}
};
template <typename T>
class RollGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("roll_grad");
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(RollGradNoNeedBufferVarsInference, "X");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(roll, ops::RollOp, ops::RollOpMaker,
ops::RollGradMaker<paddle::framework::OpDesc>,
ops::RollGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(roll_grad, ops::RollGradOp,
ops::RollGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(
roll, ops::RollKernel<paddle::platform::CPUDeviceContext, float>,
ops::RollKernel<paddle::platform::CPUDeviceContext, double>,
ops::RollKernel<paddle::platform::CPUDeviceContext, int>,
ops::RollKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
roll_grad, ops::RollGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::RollGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::RollGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::RollGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/roll_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
roll, ops::RollKernel<paddle::platform::CUDADeviceContext, float>,
ops::RollKernel<paddle::platform::CUDADeviceContext, double>,
ops::RollKernel<paddle::platform::CUDADeviceContext, int>,
ops::RollKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
roll_grad, ops::RollGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::RollGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::RollGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::RollGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DDim = framework::DDim;
template <typename T>
inline void shift_along_dim(T* data, const DDim& input_dim, int64_t dim,
int64_t shift) {
if (dim < 0) {
dim += input_dim.size();
}
shift = shift % input_dim[dim];
if (shift < 0) {
shift += input_dim[dim];
}
auto outer_loops = 1;
for (auto i = 0; i < dim; i++) {
outer_loops *= input_dim[i];
}
auto slice_width = 1;
for (auto i = dim + 1; i < input_dim.size(); i++) {
slice_width *= input_dim[i];
}
VLOG(3) << "shift_along_dim_debug: input_dim: " << input_dim
<< "; dim: " << dim << "; shift: " << shift
<< "; outer_loops: " << outer_loops
<< "; slice_width: " << slice_width;
if (shift == 0) {
return;
}
std::vector<T> head;
auto head_size = slice_width * (input_dim[dim] - shift);
head.resize(head_size);
for (auto i = 0; i < outer_loops; i++) {
for (auto j = 0; j < head_size; j++) {
head[j] = data[i * input_dim[dim] * slice_width + j];
}
for (auto j = input_dim[dim] - shift; j < input_dim[dim]; j++) {
auto dst_pos = j - input_dim[dim] + shift;
for (auto k = 0; k < slice_width; k++) {
data[(i * input_dim[dim] + dst_pos) * slice_width + k] =
data[(i * input_dim[dim] + j) * slice_width + k];
}
}
for (auto j = 0; j < head_size; j++) {
data[(i * input_dim[dim] + shift) * slice_width + j] = head[j];
}
}
}
template <typename DeviceContext, typename T>
class RollKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input_var = context.InputVar("X");
auto* output_var = context.OutputVar("Out");
auto& input = input_var->Get<LoDTensor>();
auto* output = output_var->GetMutable<LoDTensor>();
std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("dims");
std::vector<T> out_vec;
TensorToVector(input, context.device_context(), &out_vec);
size_t nums = shifts.size();
const DDim input_dim = input.dims();
for (size_t i = 0; i < nums; i++) {
PADDLE_ENFORCE_EQ(
dims[i] < input_dim.size() && dims[i] >= (0 - input_dim.size()), true,
platform::errors::OutOfRange(
"Attr(dims[%d]) is out of range, It's expected "
"to be in range of [-%d, %d]. But received Attr(dims[%d]) = %d.",
i, input_dim.size(), input_dim.size() - 1, i, dims[i]));
shift_along_dim(out_vec.data(), input_dim, dims[i], shifts[i]);
}
output->mutable_data<T>(context.GetPlace());
framework::TensorFromVector(out_vec, context.device_context(), output);
output->Resize(input_dim);
}
};
template <typename DeviceContext, typename T>
class RollGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input_var = context.InputVar(framework::GradVarName("Out"));
auto* output_var = context.OutputVar(framework::GradVarName("X"));
auto& input = input_var->Get<LoDTensor>();
auto* output = output_var->GetMutable<LoDTensor>();
std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("dims");
std::vector<T> out_vec;
TensorToVector(input, context.device_context(), &out_vec);
size_t nums = shifts.size();
const DDim input_dim = input.dims();
for (size_t i = 0; i < nums; i++) {
shift_along_dim(out_vec.data(), input_dim, dims[i], 0 - shifts[i]);
}
output->mutable_data<T>(context.GetPlace());
framework::TensorFromVector(out_vec, context.device_context(), output);
output->Resize(input_dim);
}
};
} // namespace operators
} // namespace paddle
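shift_along_dim above normalizes shift into [0, input_dim[dim]) and rotates that axis in place, which is exactly np.roll along a single axis; the grad kernel undoes each rotation by shifting with 0 - shifts[i]. A minimal NumPy sketch of the documented semantics, including the flatten-and-restore case when no dims are given (np_roll is an illustrative helper, not part of the commit):

import numpy as np

def np_roll(x, shifts, dims=None):
    if dims is None:
        # No dims: flatten, roll, restore shape, per the op's DOC string.
        return np.roll(x.reshape(-1), shifts).reshape(x.shape)
    out = x
    for shift, dim in zip(shifts, dims):
        # One shift_along_dim pass per (shift, dim) pair.
        out = np.roll(out, shift, axis=dim)
    return out

x = np.arange(1.0, 10.0).reshape(3, 3)
assert (np_roll(x, 1) == np.array(
    [[9., 1., 2.], [3., 4., 5.], [6., 7., 8.]])).all()
assert (np_roll(x, [1], dims=[0]) == np.array(
    [[7., 8., 9.], [1., 2., 3.], [4., 5., 6.]])).all()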
@@ -34,8 +34,8 @@ class WhereIndexOp : public framework::OperatorWithKernel {
 protected:
 framework::OpKernelType GetExpectedKernelType(
 const framework::ExecutionContext& ctx) const override {
-auto output_type = framework::proto::VarType::INT64;
-return framework::OpKernelType(output_type, ctx.device_context());
+auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Condition");
+return framework::OpKernelType(data_type, ctx.device_context());
 }
 };
@@ -55,4 +55,8 @@ class WhereIndexOpMaker : public framework::OpProtoAndCheckerMaker {
 namespace ops = paddle::operators;
 REGISTER_OP_WITHOUT_GRADIENT(where_index, ops::WhereIndexOp,
 ops::WhereIndexOpMaker);
-REGISTER_OP_CPU_KERNEL(where_index, ops::CPUWhereIndexKernel<int64_t>);
+REGISTER_OP_CPU_KERNEL(where_index, ops::CPUWhereIndexKernel<int64_t>,
+ops::CPUWhereIndexKernel<int>,
+ops::CPUWhereIndexKernel<bool>,
+ops::CPUWhereIndexKernel<float>,
+ops::CPUWhereIndexKernel<double>);
@@ -35,40 +35,40 @@ class CUDAWhereIndexKernel : public framework::OpKernel<T> {
 framework::Tensor cond_cpu;
 framework::TensorCopy(*condition, platform::CPUPlace(), &cond_cpu);
-const bool* cond_data = cond_cpu.data<bool>();
+const T* cond_data = cond_cpu.data<T>();
 int64_t numel = cond_cpu.numel();
 auto dims = cond_cpu.dims();
 int rank = dims.size();
-thrust::host_vector<int> h_true_index;
+thrust::host_vector<int64_t> h_true_index;
 for (int64_t i = 0; i < numel; i++) {
-if (cond_data[i]) {
+if (static_cast<bool>(cond_data[i])) {
 h_true_index.push_back(i);
 }
 }
-thrust::device_vector<int> d_true_index = h_true_index;
-int* ptr_true_index = thrust::raw_pointer_cast(d_true_index.data());
+thrust::device_vector<int64_t> d_true_index = h_true_index;
+int64_t* ptr_true_index = thrust::raw_pointer_cast(d_true_index.data());
 size_t true_num = h_true_index.size();
 out->Resize(framework::make_ddim({static_cast<int64_t>(true_num), rank}));
-auto out_ptr = out->mutable_data<T>(context.GetPlace());
+auto out_ptr = out->mutable_data<int64_t>(context.GetPlace());
 if (true_num == 0) {
 return;
 }
-thrust::host_vector<int> h_stride(rank, 0);
+thrust::host_vector<int64_t> h_stride(rank, 0);
 h_stride[rank - 1] = 1;
 for (int i = rank - 2; i >= 0; i--) {
 h_stride[i] = h_stride[i + 1] * dims[i + 1];
 }
-thrust::device_vector<int> d_stride = h_stride;
-int* ptr_stride = thrust::raw_pointer_cast(d_stride.data());
+thrust::device_vector<int64_t> d_stride = h_stride;
+int64_t* ptr_stride = thrust::raw_pointer_cast(d_stride.data());
 auto& dev_ctx = context.template device_context<CUDADeviceContext>();
-WhereIndexFunctor<int*> functor(ptr_true_index, true_num, ptr_stride, rank,
-out_ptr);
+WhereIndexFunctor<int64_t> functor(ptr_true_index, true_num, ptr_stride,
+rank, out_ptr);
 platform::ForRange<CUDADeviceContext> for_range(dev_ctx, true_num);
 for_range(functor);
 }
@@ -78,4 +78,8 @@ class CUDAWhereIndexKernel : public framework::OpKernel<T> {
 } // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(where_index, ops::CUDAWhereIndexKernel<int64_t>);
+REGISTER_OP_CUDA_KERNEL(where_index, ops::CUDAWhereIndexKernel<int64_t>,
+ops::CUDAWhereIndexKernel<int>,
+ops::CUDAWhereIndexKernel<bool>,
+ops::CUDAWhereIndexKernel<float>,
+ops::CUDAWhereIndexKernel<double>);
@@ -25,8 +25,8 @@ namespace operators {
 template <typename T>
 struct WhereIndexFunctor {
-WhereIndexFunctor(const T& true_index, int true_num, const T& stride,
-int rank, int64_t* out)
+WhereIndexFunctor(const T* true_index, int true_num, const T* stride,
+int rank, T* out)
 : true_index_(true_index),
 true_num_(true_num),
 stride_(stride),
@@ -34,18 +34,18 @@ struct WhereIndexFunctor {
 out_ptr_(out) {}
 HOSTDEVICE void operator()(size_t idx) const {
-int index = true_index_[idx];
+T index = true_index_[idx];
 for (int j = 0; j < rank_; j++) {
 out_ptr_[idx * rank_ + j] = index / stride_[j];
 index -= out_ptr_[idx * rank_ + j] * stride_[j];
 }
 }
-const T true_index_;
+const T* true_index_;
 int true_num_;
-const T stride_;
+const T* stride_;
 int rank_;
-int64_t* out_ptr_;
+T* out_ptr_;
 };
 using CPUDeviceContext = paddle::platform::CPUDeviceContext;
@@ -57,35 +57,35 @@ class CPUWhereIndexKernel : public framework::OpKernel<T> {
 auto* condition = context.Input<framework::Tensor>("Condition");
 auto* out = context.Output<framework::Tensor>("Out");
-const bool* cond_data = condition->data<bool>();
+const T* cond_data = condition->data<T>();
 auto numel = condition->numel();
 auto dims = condition->dims();
 const int rank = dims.size();
-std::vector<int> true_index;
+std::vector<int64_t> true_index;
 for (auto i = 0; i < numel; i++) {
-if (cond_data[i]) {
+if (static_cast<bool>(cond_data[i])) {
 true_index.push_back(i);
 }
 }
 auto true_num = true_index.size();
 out->Resize(framework::make_ddim({static_cast<int64_t>(true_num), rank}));
-auto out_ptr = out->mutable_data<T>(context.GetPlace());
+auto out_ptr = out->mutable_data<int64_t>(context.GetPlace());
 if (true_num == 0) {
 return;
 }
-std::vector<int> stride(rank);
+std::vector<int64_t> stride(rank);
 stride[rank - 1] = 1;
 for (int i = rank - 2; i >= 0; i--) {
 stride[i] = stride[i + 1] * dims[i + 1];
 }
 auto& dev_ctx = context.template device_context<CPUDeviceContext>();
-WhereIndexFunctor<int*> functor(true_index.data(), true_num, stride.data(),
-rank, out_ptr);
+WhereIndexFunctor<int64_t> functor(true_index.data(), true_num,
+stride.data(), rank, out_ptr);
 platform::ForRange<CPUDeviceContext> for_range(dev_ctx, true_num);
 for_range(functor);
 }
......
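WhereIndexFunctor above decodes each flat (row-major) index of a true element into a rank-length coordinate by repeated division with precomputed strides. A minimal NumPy sketch of that decoding (np_where_index is an illustrative helper, not part of the commit):

import numpy as np

def np_where_index(condition):
    cond = np.asarray(condition)
    flat_true = np.flatnonzero(cond)       # flat indices of true elements
    rank = cond.ndim
    # Row-major strides, built exactly as in the CPU/CUDA kernels.
    stride = [1] * rank
    for i in range(rank - 2, -1, -1):
        stride[i] = stride[i + 1] * cond.shape[i + 1]
    out = np.empty((len(flat_true), rank), dtype=np.int64)
    for row, index in enumerate(flat_true):
        for j in range(rank):
            # coord[j] = index / stride[j]; the remainder carries on to
            # the next axis, as in WhereIndexFunctor::operator().
            out[row, j] = index // stride[j]
            index -= out[row, j] * stride[j]
    return out

cond = np.array([[True, False], [False, True]])
assert (np_where_index(cond) == np.argwhere(cond)).all()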
@@ -154,7 +154,7 @@ from .tensor.linalg import norm #DEFINE_ALIAS
 # from .tensor.linalg import transpose #DEFINE_ALIAS
 from .tensor.linalg import dist #DEFINE_ALIAS
 from .tensor.linalg import t #DEFINE_ALIAS
-# from .tensor.linalg import cross #DEFINE_ALIAS
+from .tensor.linalg import cross #DEFINE_ALIAS
 # from .tensor.linalg import cholesky #DEFINE_ALIAS
 # from .tensor.linalg import .tensordot #DEFINE_ALIAS
 # from .tensor.manipulation import cast #DEFINE_ALIAS
@@ -182,7 +182,7 @@ from .tensor.linalg import t #DEFINE_ALIAS
 # from .tensor.manipulation import unstack #DEFINE_ALIAS
 from .tensor.manipulation import flip #DEFINE_ALIAS
 # from .tensor.manipulation import unbind #DEFINE_ALIAS
-# from .tensor.manipulation import roll #DEFINE_ALIAS
+from .tensor.manipulation import roll #DEFINE_ALIAS
 from .tensor.search import argmax #DEFINE_ALIAS
 # from .tensor.search import argmin #DEFINE_ALIAS
 # from .tensor.search import argsort #DEFINE_ALIAS
@@ -191,9 +191,9 @@ from .tensor.search import argmax #DEFINE_ALIAS
 # from .tensor.search import masked_select #DEFINE_ALIAS
 # from .tensor.search import topk #DEFINE_ALIAS
 # from .tensor.search import where #DEFINE_ALIAS
-# from .tensor.search import index_select #DEFINE_ALIAS
+from .tensor.search import index_select #DEFINE_ALIAS
 from .tensor.search import index_sample #DEFINE_ALIAS
-# from .tensor.search import nonzero #DEFINE_ALIAS
+from .tensor.search import nonzero #DEFINE_ALIAS
 from .tensor.search import sort #DEFINE_ALIAS
 # from .framework.framework import set_default_dtype #DEFINE_ALIAS
 # from .framework.framework import get_default_dtype #DEFINE_ALIAS
......
@@ -13275,9 +13275,11 @@ def where(condition):
 out = layers.where(condition) # [[]]
 """
-check_variable_and_dtype(condition, "condition", ['bool'], "where")
 helper = LayerHelper("where_index", **locals())
+if in_dygraph_mode():
+return core.ops.where_index(condition)
 out = helper.create_variable_for_type_inference(
 dtype=core.VarDesc.VarType.INT64)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
class TestCrossOp(OpTest):
def setUp(self):
self.op_type = "cross"
self.initTestCase()
self.inputs = {
'X': np.random.random(self.shape).astype(self.dtype),
'Y': np.random.random(self.shape).astype(self.dtype)
}
self.init_output()
def initTestCase(self):
self.attrs = {'dim': -2}
self.dtype = np.float64
self.shape = (1024, 3, 1)
def init_output(self):
x = np.squeeze(self.inputs['X'], 2)
y = np.squeeze(self.inputs['Y'], 2)
z_list = []
for i in range(1024):
z_list.append(np.cross(x[i], y[i]))
self.outputs = {'Out': np.array(z_list).reshape(self.shape)}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out')
class TestCrossOpCase1(TestCrossOp):
def initTestCase(self):
self.shape = (2048, 3)
self.dtype = np.float32
def init_output(self):
z_list = []
for i in range(2048):
z_list.append(np.cross(self.inputs['X'][i], self.inputs['Y'][i]))
self.outputs = {'Out': np.array(z_list).reshape(self.shape)}
class TestCrossAPI(unittest.TestCase):
def input_data(self):
self.data_x = np.array(
[[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])
self.data_y = np.array(
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
def test_cross_api(self):
self.input_data()
# case 1:
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[-1, 3])
y = fluid.layers.data(name='y', shape=[-1, 3])
z = paddle.cross(x, y, dim=1)
exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(feed={'x': self.data_x,
'y': self.data_y},
fetch_list=[z.name],
return_numpy=False)
expect_out = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0],
[0.0, 0.0, 0.0]])
self.assertTrue(np.allclose(expect_out, np.array(res)))
# case 2:
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[-1, 3])
y = fluid.layers.data(name='y', shape=[-1, 3])
z = paddle.cross(x, y)
exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(feed={'x': self.data_x,
'y': self.data_y},
fetch_list=[z.name],
return_numpy=False)
expect_out = np.array([[-1.0, -1.0, -1.0], [2.0, 2.0, 2.0],
[-1.0, -1.0, -1.0]])
self.assertTrue(np.allclose(expect_out, np.array(res)))
def test_dygraph_api(self):
self.input_data()
# case 1:
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(self.data_x)
y = fluid.dygraph.to_variable(self.data_y)
z = paddle.cross(x, y)
np_z = z.numpy()
expect_out = np.array([[-1.0, -1.0, -1.0], [2.0, 2.0, 2.0],
[-1.0, -1.0, -1.0]])
self.assertTrue(np.allclose(expect_out, np_z))
# case 2:
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(self.data_x)
y = fluid.dygraph.to_variable(self.data_y)
z = paddle.cross(x, y, dim=1)
np_z = z.numpy()
expect_out = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0],
[0.0, 0.0, 0.0]])
self.assertTrue(np.allclose(expect_out, np_z))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
class TestIndexSelectOp(OpTest):
def setUp(self):
self.op_type = "index_select"
self.init_dtype_type()
index_np = np.random.randint(
low=0, high=self.x_shape[self.dim], size=self.index_size)
x_np = np.random.random(self.x_shape).astype(self.x_type)
self.inputs = {'X': x_np, 'Index': index_np}
self.attrs = {'dim': self.dim}
outer_loop = np.prod(self.x_shape[:self.dim])
x_reshape = [outer_loop] + list(self.x_shape[self.dim:])
x_np_reshape = np.reshape(x_np, tuple(x_reshape))
out_list = []
for i in range(outer_loop):
for j in range(self.index_size):
out_list.append(x_np_reshape[i, index_np[j]])
self.out_shape = list(self.x_shape)
self.out_shape[self.dim] = self.index_size
self.out_shape = tuple(self.out_shape)
out = np.reshape(out_list, self.out_shape)
self.outputs = {'Out': out}
def init_dtype_type(self):
self.dim = 1
self.x_type = np.float64
self.index_type = np.int64
self.x_shape = (100, 4, 5)
self.index_size = 100
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X'], 'Out')
class TestIndexSelectOpCase2(TestIndexSelectOp):
def init_dtype_type(self):
self.x_type = np.float32
self.index_type = np.int32
self.dim = -2
self.x_shape = (10, 10, 4, 10)
self.index_size = 10
class TestIndexSelectAPI(unittest.TestCase):
def input_data(self):
self.data_x = np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0],
[9.0, 10.0, 11.0, 12.0]])
self.data_index = np.array([0, 1, 1]).astype('int32')
def test_index_select_api(self):
self.input_data()
# case 1:
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[-1, 4])
index = fluid.layers.data(
name='index', shape=[3], dtype='int32', append_batch_size=False)
z = paddle.index_select(x, index, dim=1)
exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(feed={'x': self.data_x,
'index': self.data_index},
fetch_list=[z.name],
return_numpy=False)
expect_out = np.array([[1.0, 2.0, 2.0], [5.0, 6.0, 6.0],
[9.0, 10.0, 10.0]])
self.assertTrue(np.allclose(expect_out, np.array(res)))
# case 2:
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[-1, 4])
index = fluid.layers.data(
name='index', shape=[3], dtype='int32', append_batch_size=False)
z = paddle.index_select(x, index)
exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(feed={'x': self.data_x,
'index': self.data_index},
fetch_list=[z.name],
return_numpy=False)
expect_out = np.array(
[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [5.0, 6.0, 7.0, 8.0]])
self.assertTrue(np.allclose(expect_out, np.array(res)))
def test_dygraph_api(self):
self.input_data()
# case 1:
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(self.data_x)
index = fluid.dygraph.to_variable(self.data_index)
z = paddle.index_select(x, index)
np_z = z.numpy()
expect_out = np.array(
[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [5.0, 6.0, 7.0, 8.0]])
self.assertTrue(np.allclose(expect_out, np_z))
# case 2:
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(self.data_x)
index = fluid.dygraph.to_variable(self.data_index)
z = paddle.index_select(x, index, dim=1)
np_z = z.numpy()
expect_out = np.array([[1.0, 2.0, 2.0], [5.0, 6.0, 6.0],
[9.0, 10.0, 10.0]])
self.assertTrue(np.allclose(expect_out, np_z))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
class TestNonZeroAPI(unittest.TestCase):
def test_nonzero_api_as_tuple(self):
data = np.array([[True, False], [False, True]])
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[-1, 2])
y = paddle.nonzero(x, as_tuple=True)
self.assertEqual(type(y), tuple)
self.assertEqual(len(y), 2)
z = fluid.layers.concat(list(y), axis=1)
exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(feed={'x': data},
fetch_list=[z.name],
return_numpy=False)
expect_out = np.array([[0, 0], [1, 1]])
self.assertTrue(np.allclose(expect_out, np.array(res)))
data = np.array([True, True, False])
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[-1])
y = paddle.nonzero(x, as_tuple=True)
self.assertEqual(type(y), tuple)
self.assertEqual(len(y), 1)
z = fluid.layers.concat(list(y), axis=1)
exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(feed={'x': data},
fetch_list=[z.name],
return_numpy=False)
expect_out = np.array([[0], [1]])
self.assertTrue(np.allclose(expect_out, np.array(res)))
def test_nonzero_api(self):
data = np.array([[True, False], [False, True]])
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[-1, 2])
y = paddle.nonzero(x)
exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(feed={'x': data},
fetch_list=[y.name],
return_numpy=False)
expect_out = np.array([[0, 0], [1, 1]])
self.assertTrue(np.allclose(expect_out, np.array(res)))
data = np.array([True, True, False])
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[-1])
y = paddle.nonzero(x)
exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(feed={'x': data},
fetch_list=[y.name],
return_numpy=False)
expect_out = np.array([[0], [1]])
self.assertTrue(np.allclose(expect_out, np.array(res)))
if __name__ == "__main__":
unittest.main()
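The tests above pin down the nonzero semantics: the default form returns an (n, rank) int64 index matrix (the where_index output), and as_tuple=True yields one (n, 1) column per axis, which is why concatenating the tuple along axis=1 reproduces the matrix. A hedged NumPy model inferred from these tests (np_nonzero is an illustrative helper, not part of the commit):

import numpy as np

def np_nonzero(x, as_tuple=False):
    # Default: (n, rank) matrix of coordinates of the true elements.
    out = np.argwhere(np.asarray(x))
    if not as_tuple:
        return out
    # as_tuple=True: one (n, 1) column per axis.
    return tuple(out[:, j:j + 1] for j in range(out.shape[1]))

assert (np.concatenate(np_nonzero([[True, False], [False, True]],
                                  as_tuple=True), axis=1)
        == np.array([[0, 0], [1, 1]])).all()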
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
class TestRollOp(OpTest):
def setUp(self):
self.op_type = "roll"
self.init_dtype_type()
self.inputs = {'X': np.random.random(self.x_shape).astype(self.dtype)}
self.attrs = {'shifts': self.shifts, 'dims': self.dims}
self.outputs = {
'Out': np.roll(self.inputs['X'], self.attrs['shifts'],
self.attrs['dims'])
}
def init_dtype_type(self):
self.dtype = np.float64
self.x_shape = (100, 4, 5)
self.shifts = [101, -1]
self.dims = [0, -2]
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X'], 'Out')
class TestRollOpCase2(TestRollOp):
def init_dtype_type(self):
self.dtype = np.float32
self.x_shape = (100, 100, 5)
self.shifts = [8, -1]
self.dims = [-1, -2]
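# A hedged extra case (illustrative, not in the original patch): a single
# shift along the last dimension, still validated against np.roll by the
# base class's setUp.
class TestRollOpCase3(TestRollOp):
    def init_dtype_type(self):
        self.dtype = np.float64
        self.x_shape = (4, 7, 5)
        self.shifts = [3]
        self.dims = [2]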
class TestRollAPI(unittest.TestCase):
def input_data(self):
self.data_x = np.array(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
    def test_roll_api(self):
self.input_data()
# case 1:
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[-1, 3])
z = paddle.roll(x, shifts=1)
exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(feed={'x': self.data_x},
fetch_list=[z.name],
return_numpy=False)
expect_out = np.array([[9.0, 1.0, 2.0], [3.0, 4.0, 5.0],
[6.0, 7.0, 8.0]])
self.assertTrue(np.allclose(expect_out, np.array(res)))
# case 2:
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[-1, 3])
z = paddle.roll(x, shifts=1, dims=0)
exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(feed={'x': self.data_x},
fetch_list=[z.name],
return_numpy=False)
expect_out = np.array([[7.0, 8.0, 9.0], [1.0, 2.0, 3.0],
[4.0, 5.0, 6.0]])
self.assertTrue(np.allclose(expect_out, np.array(res)))
def test_dygraph_api(self):
self.input_data()
# case 1:
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(self.data_x)
z = paddle.roll(x, shifts=1)
np_z = z.numpy()
expect_out = np.array([[9.0, 1.0, 2.0], [3.0, 4.0, 5.0],
[6.0, 7.0, 8.0]])
self.assertTrue(np.allclose(expect_out, np_z))
# case 2:
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(self.data_x)
z = paddle.roll(x, shifts=1, dims=0)
np_z = z.numpy()
expect_out = np.array([[7.0, 8.0, 9.0], [1.0, 2.0, 3.0],
[4.0, 5.0, 6.0]])
self.assertTrue(np.allclose(expect_out, np_z))
if __name__ == "__main__":
unittest.main()
@@ -109,12 +109,6 @@ class TestWhereRaiseError(unittest.TestCase):
self.assertRaises(TypeError, test_type)
def test_dtype():
data = fluid.data(shape=[10], dtype="float32", name="input")
fluid.layers.where(data)
self.assertRaises(TypeError, test_dtype)
if __name__ == "__main__":
unittest.main()
@@ -134,7 +134,7 @@ from .linalg import norm #DEFINE_ALIAS
# from .linalg import transpose #DEFINE_ALIAS
from .linalg import dist #DEFINE_ALIAS
from .linalg import t #DEFINE_ALIAS
# from .linalg import cross #DEFINE_ALIAS
from .linalg import cross #DEFINE_ALIAS
# from .linalg import cholesky #DEFINE_ALIAS
# from .manipulation import cast #DEFINE_ALIAS
# from .manipulation import concat #DEFINE_ALIAS
@@ -161,7 +161,7 @@ from .linalg import t #DEFINE_ALIAS
# from .manipulation import unstack #DEFINE_ALIAS
from .manipulation import flip #DEFINE_ALIAS
# from .manipulation import unbind #DEFINE_ALIAS
# from .manipulation import roll #DEFINE_ALIAS
from .manipulation import roll #DEFINE_ALIAS
from .search import argmax #DEFINE_ALIAS
# from .search import argmin #DEFINE_ALIAS
# from .search import argsort #DEFINE_ALIAS
@@ -170,7 +170,7 @@ from .search import argmax #DEFINE_ALIAS
# from .search import masked_select #DEFINE_ALIAS
# from .search import topk #DEFINE_ALIAS
from .search import where #DEFINE_ALIAS
# from .search import index_select #DEFINE_ALIAS
from .search import index_select #DEFINE_ALIAS
from .search import index_sample # DEFINE_ALIAS
# from .search import nonzero #DEFINE_ALIAS
from .search import nonzero #DEFINE_ALIAS
from .search import sort #DEFINE_ALIAS
@@ -24,7 +24,7 @@ __all__ = [
# 'transpose',
'dist',
't',
# 'cross',
'cross',
# 'cholesky',
# 'tensordot'
]
@@ -529,3 +529,65 @@ def t(input, name=None):
'XShape': [input_shape]},
attrs={'axis': [1, 0]})
return out
def cross(input, other, dim=None):
"""
    Returns the cross product of vectors in dimension `dim` of the `input` and
    `other` tensors. The inputs must have the same shape, and the size of their
    dim-th dimension should be equal to 3. If `dim` is not given, it defaults to
    the first dimension found with size 3.
Args:
input (Variable): The first input tensor variable.
other (Variable): The second input tensor variable.
        dim (int): The dimension to take the cross product in. Default: None,
            which means the first dimension found with size 3 is used.
Returns:
Variable: A Tensor with same data type as `input`.
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
import numpy as np
data_x = np.array([[1.0, 1.0, 1.0],
[2.0, 2.0, 2.0],
[3.0, 3.0, 3.0]])
data_y = np.array([[1.0, 1.0, 1.0],
[1.0, 1.0, 1.0],
[1.0, 1.0, 1.0]])
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(data_x)
y = fluid.dygraph.to_variable(data_y)
out_z1 = paddle.cross(x, y)
print(out_z1.numpy())
#[[-1. -1. -1.]
# [ 2. 2. 2.]
# [-1. -1. -1.]]
out_z2 = paddle.cross(x, y, dim=1)
print(out_z2.numpy())
#[[0. 0. 0.]
# [0. 0. 0.]
# [0. 0. 0.]]
"""
helper = LayerHelper("cross", **locals())
if in_dygraph_mode():
        if dim is not None:  # 'if dim:' would wrongly treat dim == 0 as unset
return core.ops.cross(input, other, 'dim', dim)
else:
return core.ops.cross(input, other)
out = helper.create_variable_for_type_inference(input.dtype)
attrs = dict()
    if dim is not None:  # 'if dim:' would wrongly treat dim == 0 as unset
attrs['dim'] = dim
helper.append_op(
type='cross',
inputs={'X': input,
'Y': other},
outputs={'Out': out},
attrs=attrs)
return out
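# A minimal static-graph sketch for cross(), assuming the fluid.data/Executor
# pattern used by the unit tests above; the docstring example covers only
# dygraph mode. Parallel input vectors should give a zero cross product.
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard

with program_guard(Program(), Program()):
    x = fluid.data(name='x', shape=[3, 3], dtype='float32')
    y = fluid.data(name='y', shape=[3, 3], dtype='float32')
    z = paddle.cross(x, y, dim=1)
    exe = fluid.Executor(fluid.CPUPlace())
    res, = exe.run(feed={'x': np.ones((3, 3), dtype='float32'),
                         'y': np.ones((3, 3), dtype='float32')},
                   fetch_list=[z.name],
                   return_numpy=False)
# The cross product of parallel vectors is the zero vector.
assert np.allclose(np.array(res), np.zeros((3, 3)))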
@@ -14,7 +14,7 @@
from __future__ import print_function
from ..fluid.layers import core
from ..fluid.layers import core, reshape
from ..fluid.layer_helper import LayerHelper
from ..fluid.framework import Variable, OpProtoHolder, in_dygraph_mode, convert_np_dtype_to_dtype_
from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
@@ -46,7 +46,7 @@ __all__ = [
# 'unstack',
'flip',
# 'unbind',
# 'roll'
'roll'
]
@@ -98,3 +98,74 @@ def flip(input, dims, name=None):
outputs={"Out": out},
attrs={"dims": dims})
return out
def roll(input, shifts, dims=None):
"""
Roll the `input` tensor along the given dimension(s). Elements that are shifted beyond
the last position are re-introduced at the first position. If a dimension is not specified,
the tensor will be flattened before rolling and then restored to the original shape.
Args:
input (Variable): The input tensor variable.
shifts (int|list|tuple): The number of places by which the elements
of the `input` tensor are shifted.
        dims (int|list|tuple|None): Dimensions along which to roll. Default: None,
            which means the tensor is flattened before rolling.
Returns:
Variable: A Tensor with same data type as `input`.
Examples:
.. code-block:: python
import numpy as np
import paddle
import paddle.fluid as fluid
data = np.array([[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0],
[7.0, 8.0, 9.0]])
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(data)
out_z1 = paddle.roll(x, shifts=1)
print(out_z1.numpy())
#[[9. 1. 2.]
# [3. 4. 5.]
# [6. 7. 8.]]
out_z2 = paddle.roll(x, shifts=1, dims=0)
print(out_z2.numpy())
#[[7. 8. 9.]
# [1. 2. 3.]
# [4. 5. 6.]]
"""
helper = LayerHelper("roll", **locals())
origin_shape = input.shape
    if isinstance(shifts, int):
        shifts = [shifts]
    if isinstance(dims, int):
        dims = [dims]
if dims:
check_type(dims, 'dims', (list, tuple), 'roll')
check_type(shifts, 'shifts', (list, tuple), 'roll')
if in_dygraph_mode():
if dims is None:
input = core.ops.reshape(input, 'shape', [-1, 1])
dims = [0]
out = core.ops.roll(input, 'dims', dims, 'shifts', shifts)
return core.ops.reshape(out, 'shape', origin_shape)
out = helper.create_variable_for_type_inference(input.dtype)
if dims is None:
input = reshape(input, shape=[-1, 1])
dims = [0]
helper.append_op(
type='roll',
inputs={'X': input},
outputs={'Out': out},
attrs={'dims': dims,
'shifts': shifts})
out = reshape(out, shape=origin_shape, inplace=True)
return out
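# A minimal dygraph sketch of the flatten-then-restore behavior of roll() when
# dims is None, assuming equivalence with np.roll on the flattened array
# (which is what the reshape([-1, 1]) path above implements).
import numpy as np
import paddle
import paddle.fluid as fluid

data = np.arange(9, dtype='float64').reshape(3, 3)
with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(data)
    np_z = paddle.roll(x, shifts=2).numpy()
# Rolling without dims flattens, shifts, then restores the original shape.
assert np.allclose(np_z, np.roll(data.reshape(-1), 2).reshape(3, 3))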
@@ -27,8 +27,8 @@ __all__ = [
# 'masked_select',
# 'topk',
'where',
# 'index_select',
# 'nonzero',
'index_select',
'nonzero',
'sort',
'index_sample'
]
@@ -126,6 +126,151 @@ def argmax(input, axis=None, dtype=None, out=None, keepdims=False, name=None):
return out
def index_select(input, index, dim=0):
"""
    Returns a new tensor which indexes the `input` tensor along dimension `dim`
    using the entries in `index`, which is a 1-D Tensor. The returned tensor has
    the same number of dimensions as the original `input` tensor. The dim-th
    dimension has the same size as the length of `index`; other dimensions keep
    the same sizes as in the `input` tensor.
Args:
input (Variable): The input tensor variable.
index (Variable): The 1-D tensor containing the indices to index.
dim (int): The dimension in which we index.
Returns:
Variable: A Tensor with same data type as `input`.
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
import numpy as np
data = np.array([[1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0],
[9.0, 10.0, 11.0, 12.0]])
data_index = np.array([0, 1, 1]).astype('int32')
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(data)
index = fluid.dygraph.to_variable(data_index)
out_z1 = paddle.index_select(x, index)
print(out_z1.numpy())
#[[1. 2. 3. 4.]
# [5. 6. 7. 8.]
# [5. 6. 7. 8.]]
out_z2 = paddle.index_select(x, index, dim=1)
print(out_z2.numpy())
#[[ 1. 2. 2.]
# [ 5. 6. 6.]
# [ 9. 10. 10.]]
"""
helper = LayerHelper("index_select", **locals())
if in_dygraph_mode():
return core.ops.index_select(input, index, 'dim', dim)
    check_variable_and_dtype(input, 'input',
                             ['float32', 'float64', 'int32', 'int64'],
                             'paddle.tensor.search.index_select')
    check_variable_and_dtype(index, 'index', ['int32', 'int64'],
                             'paddle.tensor.search.index_select')
out = helper.create_variable_for_type_inference(input.dtype)
helper.append_op(
type='index_select',
inputs={'X': input,
'Index': index},
outputs={'Out': out},
attrs={'dim': dim})
return out
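# A minimal static-graph sketch for index_select(), assuming the same
# fluid.data/Executor pattern as the tests; the result should match numpy
# integer-array indexing along dim 0.
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard

data = np.array([[1.0, 2.0, 3.0, 4.0],
                 [5.0, 6.0, 7.0, 8.0],
                 [9.0, 10.0, 11.0, 12.0]]).astype('float32')
data_index = np.array([0, 1, 1]).astype('int32')
with program_guard(Program(), Program()):
    x = fluid.data(name='x', shape=[3, 4], dtype='float32')
    index = fluid.data(name='index', shape=[3], dtype='int32')
    out = paddle.index_select(x, index)
    exe = fluid.Executor(fluid.CPUPlace())
    res, = exe.run(feed={'x': data, 'index': data_index},
                   fetch_list=[out.name],
                   return_numpy=False)
# index_select along dim 0 matches numpy fancy indexing.
assert np.allclose(np.array(res), data[[0, 1, 1]])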
def nonzero(input, as_tuple=False):
"""
    Return a tensor containing the indices of all non-zero elements of the `input`
    tensor. If as_tuple is True, return a tuple of 1-D tensors, one for each dimension
    in `input`, each containing the indices (in that dimension) of all non-zero elements
    of `input`. Given an n-dimensional `input` tensor with shape [x_1, x_2, ..., x_n], if
    as_tuple is False, the output is a tensor with shape [z, n], where `z` is the
    number of non-zero elements in the `input` tensor. If as_tuple is True, the output is
    a tuple of length `n`, whose elements are tensors of shape [z, 1].
    Args:
        input (Variable): The input tensor variable.
        as_tuple (bool): Whether to return the indices as a tuple of per-dimension
            index tensors instead of a single tensor. Default: False.
    Returns:
        Variable: The indices of the non-zero elements. The data type is int64.
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
import numpy as np
data1 = np.array([[1.0, 0.0, 0.0],
[0.0, 2.0, 0.0],
[0.0, 0.0, 3.0]])
data2 = np.array([0.0, 1.0, 0.0, 3.0])
data3 = np.array([0.0, 0.0, 0.0])
with fluid.dygraph.guard():
x1 = fluid.dygraph.to_variable(data1)
x2 = fluid.dygraph.to_variable(data2)
x3 = fluid.dygraph.to_variable(data3)
out_z1 = paddle.nonzero(x1)
print(out_z1.numpy())
#[[0 0]
# [1 1]
# [2 2]]
out_z1_tuple = paddle.nonzero(x1, as_tuple=True)
for out in out_z1_tuple:
print(out.numpy())
#[[0]
# [1]
# [2]]
#[[0]
# [1]
# [2]]
out_z2 = paddle.nonzero(x2)
print(out_z2.numpy())
#[[1]
# [3]]
out_z2_tuple = paddle.nonzero(x2, as_tuple=True)
for out in out_z2_tuple:
print(out.numpy())
#[[1]
# [3]]
out_z3 = paddle.nonzero(x3)
print(out_z3.numpy())
#[]
out_z3_tuple = paddle.nonzero(x3, as_tuple=True)
for out in out_z3_tuple:
print(out.numpy())
#[]
"""
list_out = []
shape = input.shape
rank = len(shape)
if in_dygraph_mode():
outs = core.ops.where_index(input)
else:
outs = layers.where(input)
if not as_tuple:
return outs
elif rank == 1:
return tuple([outs])
else:
for i in range(rank):
list_out.append(
layers.slice(
outs, axes=[rank - 1], starts=[i], ends=[i + 1]))
return tuple(list_out)
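# A minimal dygraph sketch relating the two nonzero() layouts: concatenating
# the as_tuple=True columns should reproduce the as_tuple=False result
# (assumed from the slice-based construction above; the static-graph tests
# check the same identity with fluid.layers.concat).
import numpy as np
import paddle
import paddle.fluid as fluid

data = np.array([[1.0, 0.0], [0.0, 2.0]])
with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(data)
    dense = paddle.nonzero(x).numpy()  # shape [z, n]
    cols = [t.numpy() for t in paddle.nonzero(x, as_tuple=True)]
assert np.allclose(dense, np.concatenate(cols, axis=1))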
def sort(input, axis=-1, descending=False, out=None, name=None):
"""
This OP sorts the input along the given axis, and returns sorted output
......