未验证 提交 f147fc99 编写于 作者: H Huihuang Zheng 提交者: GitHub

Put_along_axis (based on PR #37921 by Xu Huang) (#38608)

Paddle new APIs: put_along_axis.

Xu Huang is on holiday so we created this PR to work on it. It is based on his PR: https://github.com/PaddlePaddle/Paddle/pull/37921
上级 f1366d58
......@@ -34,9 +34,17 @@ class ReduceAdd {
*self_data += *src_data;
}
};
static ReduceAdd reduce_add;
class ReduceMultiply {
public:
template <typename tensor_t>
void operator()(tensor_t* self_data, tensor_t* src_data) const {
*self_data *= *src_data;
}
};
static ReduceMultiply reduce_mul;
template <typename tensor_t, typename index_t = int64_t,
bool is_scatter_like = true>
struct cpu_gather_scatter_functor {
......@@ -75,7 +83,6 @@ struct cpu_gather_scatter_functor {
for (int i = dim + 1; i < index_dims.size(); i++) {
outer_dim_size *= index_dims[i];
}
int64_t index_idx = 0;
int64_t self_idx, src_idx;
......@@ -141,8 +148,55 @@ void cpu_scatter_add_kernel(Tensor self, int dim, const Tensor& index,
self, dim, index, src, "scatter_add_cpu", reduce_add, ctx);
}
template <typename tensor_t, typename index_t>
void cpu_scatter_mul_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx) {
cpu_gather_scatter_functor<tensor_t, index_t,
/*is_scatter_like=*/true>()(
self, dim, index, src, "scatter_mul_cpu", reduce_mul, ctx);
}
template <typename tensor_t, typename index_t>
void cpu_scatter_input_grad_kernel(Tensor self, int dim, const Tensor& index,
Tensor output,
const platform::DeviceContext& ctx) {
auto* index_data = index.data<index_t>();
auto* output_data = output.data<tensor_t>();
auto index_dims = index.dims();
auto output_dims = output.dims();
int64_t inner_dim_size = 1;
int64_t outer_dim_size = 1;
int select_dim_size = index_dims[dim];
int output_select_dim_size = output_dims[dim];
for (int64_t i = 0; i < dim; ++i) {
inner_dim_size *= index_dims[i];
}
for (int i = dim + 1; i < index_dims.size(); i++) {
outer_dim_size *= index_dims[i];
}
int64_t index_idx = 0;
for (int64_t i = 0; i < inner_dim_size; i++) {
for (int64_t j = 0; j < select_dim_size; j++) {
for (int64_t k = 0; k < outer_dim_size; k++) {
int64_t index = index_data[index_idx];
int64_t replace_index = k + index * outer_dim_size +
i * outer_dim_size * output_select_dim_size;
output_data[replace_index] = 0;
index_idx++;
}
}
}
}
Instantiate_Template_Function(cpu_gather_kernel)
Instantiate_Template_Function(cpu_scatter_add_kernel)
Instantiate_Template_Function(cpu_scatter_assign_kernel)
Instantiate_Template_Function(cpu_scatter_add_kernel)
Instantiate_Template_Function(cpu_scatter_mul_kernel)
Instantiate_Template_Function(cpu_scatter_input_grad_kernel)
} // namespace operators
} // namespace paddle
......@@ -45,6 +45,16 @@ class ReduceAdd {
};
static ReduceAdd reduce_add;
class ReduceMul {
public:
template <typename tensor_t>
__device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
*self_data *= *src_data;
// TODO(huangxu96) platform::CudaAtomicMul(*self_data, *src_data);
}
};
static ReduceMul reduce_mul;
template <typename tensor_t, typename index_t, typename func_t,
bool is_scatter_like = true>
__global__ void GatherScatterGPUKernel(
......@@ -141,6 +151,14 @@ void gpu_gather_kernel(Tensor self, int dim, const Tensor& index, Tensor result,
return;
}
template <typename tensor_t, typename index_t>
void gpu_scatter_assign_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx) {
gpu_gather_scatter_functor<tensor_t, index_t,
/*is_scatter_like=*/true>()(
self, dim, index, src, "scatter_assign_gpu", tensor_assign, ctx);
}
template <typename tensor_t, typename index_t>
void gpu_scatter_add_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx) {
......@@ -149,9 +167,72 @@ void gpu_scatter_add_kernel(Tensor self, int dim, const Tensor& index,
self, dim, index, src, "scatter_add_gpu", reduce_add, ctx);
}
namespace plat = paddle::platform;
template <typename tensor_t, typename index_t>
void gpu_scatter_mul_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx) {
gpu_gather_scatter_functor<tensor_t, index_t,
/*is_scatter_like=*/true>()(
self, dim, index, src, "scatter_mul_gpu", reduce_mul, ctx);
}
template <typename tensor_t, typename index_t>
__global__ void ScatterInputGradGPUKernel(
tensor_t* grad_data, int dim, const index_t* index_data,
int64_t inner_dim_size, int select_dim_size, int grad_select_dim_size,
int64_t outer_dim_size, int64_t numel) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= numel) return;
int64_t i, j, k;
i = tid / (select_dim_size * outer_dim_size);
int64_t remind = tid % (select_dim_size * outer_dim_size);
j = remind / outer_dim_size;
k = remind % outer_dim_size;
index_t index = index_data[tid];
int64_t replace_index =
k + index * outer_dim_size + i * outer_dim_size * grad_select_dim_size;
grad_data[replace_index] = 0;
}
template <typename tensor_t, typename index_t>
void gpu_scatter_input_grad_kernel(Tensor self, int dim, const Tensor& index,
Tensor grad,
const platform::DeviceContext& ctx) {
auto* index_data = index.data<index_t>();
auto* grad_data = grad.data<tensor_t>();
auto index_dims = index.dims();
auto grad_dims = grad.dims();
int64_t index_size = index.numel();
int64_t inner_dim_size = 1;
int64_t outer_dim_size = 1;
int select_dim_size = index_dims[dim];
int grad_select_dim_size = grad_dims[dim];
for (int64_t i = 0; i < dim; ++i) {
inner_dim_size *= index_dims[i];
}
for (int i = dim + 1; i < index_dims.size(); i++) {
outer_dim_size *= index_dims[i];
}
int64_t slice_size = 1;
for (int i = 1; i < grad_dims.size(); ++i) slice_size *= grad_dims[i];
int block = 512;
int64_t n = slice_size * index_size;
int64_t grid = (n + block - 1) / block;
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
ScatterInputGradGPUKernel<tensor_t, index_t><<<grid, block, 0, stream>>>(
grad_data, dim, index_data, inner_dim_size, select_dim_size,
grad_select_dim_size, outer_dim_size, index_size);
}
Instantiate_Template_Function(gpu_gather_kernel)
Instantiate_Template_Function(gpu_scatter_add_kernel)
Instantiate_Template_Function(gpu_scatter_assign_kernel)
Instantiate_Template_Function(gpu_scatter_add_kernel)
Instantiate_Template_Function(gpu_scatter_mul_kernel)
Instantiate_Template_Function(gpu_scatter_input_grad_kernel)
} // namespace operators
} // namespace paddle
......@@ -41,17 +41,42 @@ template <typename tensor_t, typename index_t>
void cpu_gather_kernel(Tensor self, int dim, const Tensor& index, Tensor result,
const platform::DeviceContext& ctx);
template <typename tensor_t, typename index_t>
void cpu_scatter_assign_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx);
template <typename tensor_t, typename index_t>
void cpu_scatter_add_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx);
template <typename tensor_t, typename index_t>
void cpu_scatter_mul_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx);
template <typename tensor_t, typename index_t>
void cpu_scatter_input_grad_kernel(Tensor self, int dim, const Tensor& index,
Tensor result,
const platform::DeviceContext& ctx);
template <typename tensor_t, typename index_t>
void gpu_gather_kernel(Tensor self, int dim, const Tensor& index, Tensor result,
const platform::DeviceContext& ctx);
template <typename tensor_t, typename index_t>
void gpu_scatter_assign_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx);
template <typename tensor_t, typename index_t>
void gpu_scatter_add_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx);
template <typename tensor_t, typename index_t>
void gpu_scatter_mul_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx);
template <typename tensor_t, typename index_t>
void gpu_scatter_input_grad_kernel(Tensor self, int dim, const Tensor& index,
Tensor result,
const platform::DeviceContext& ctx);
} // namespace operators
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/put_along_axis_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace operators {
class PutAlongAxisOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "PutAlongAxis");
OP_INOUT_CHECK(ctx->HasInput("Index"), "Input", "Index", "PutAlongAxis");
OP_INOUT_CHECK(ctx->HasInput("Value"), "Input", "Value", "PutAlongAxis");
OP_INOUT_CHECK(ctx->HasOutput("Result"), "Output", "Result",
"PutAlongAxis");
auto index_dim = ctx->GetInputDim("Index");
ctx->SetOutputDim("Result", index_dim);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
class PutAlongAxisOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input", "The input tensor of PutAlongAxisOp");
AddInput("Index", "The index tensor of PutAlongAxisOp");
AddInput("Value", "The value tensor of PutAlongAxisOp");
AddOutput("Result", "The result tensor of PutAlongAxisOp");
AddAttr<int>("Axis", "The axis that we do PutAlongAxis operation");
AddAttr<std::string>("Reduce", "The reduce operation for scatter")
.SetDefault("assign");
AddComment(R"DOC(
PutAlongAxis Operator.)
)DOC");
}
};
class PutAlongAxisGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("Input"),
ctx->GetInputDim("Input"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Result")),
ctx.device_context());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
template <typename T>
class PutAlongAxisGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("put_along_axis_grad");
op->SetInput("Index", this->Input("Index"));
op->SetInput("Input", this->Input("Input"));
op->SetInput(framework::GradVarName("Result"), this->OutputGrad("Result"));
op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
op->SetOutput(framework::GradVarName("Value"), this->InputGrad("Value"));
op->SetAttrMap(this->Attrs());
}
};
DECLARE_INPLACE_OP_INFERER(PutAlongAxisInplaceInferer, {"Input", "Result"});
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(put_along_axis, ops::PutAlongAxisOp, ops::PutAlongAxisOpMaker,
ops::PutAlongAxisGradOpMaker<paddle::framework::OpDesc>,
ops::PutAlongAxisGradOpMaker<paddle::imperative::OpBase>,
paddle::operators::PutAlongAxisInplaceInferer);
REGISTER_OPERATOR(put_along_axis_grad, ops::PutAlongAxisGradOp);
REGISTER_OP_CPU_KERNEL(put_along_axis, ops::PutAlongAxisOpKernel<float>,
ops::PutAlongAxisOpKernel<double>,
ops::PutAlongAxisOpKernel<int>,
ops::PutAlongAxisOpKernel<uint8_t>,
ops::PutAlongAxisOpKernel<int64_t>);
REGISTER_OP_CPU_KERNEL(put_along_axis_grad,
ops::PutAlongAxisGradOpKernel<float>,
ops::PutAlongAxisGradOpKernel<double>,
ops::PutAlongAxisGradOpKernel<int>,
ops::PutAlongAxisGradOpKernel<uint8_t>,
ops::PutAlongAxisGradOpKernel<int64_t>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/put_along_axis_op.h"
namespace paddle {
namespace operators {
template <typename T>
class PutAlongAxisCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet(
"PutAlongAxisCUDAKernel only runs on GPU device."));
auto input = ctx.Input<Tensor>("Input");
auto axis = ctx.Attr<int>("Axis");
auto value = ctx.Input<Tensor>("Value");
auto index = ctx.Input<Tensor>("Index");
auto reduce_op = ctx.Attr<std::string>("Reduce");
auto result = ctx.Output<Tensor>("Result");
const platform::DeviceContext &device_ctx = ctx.device_context();
const auto &index_type = index->type();
framework::TensorCopy(*input, ctx.GetPlace(), result);
if (reduce_op == "add") {
if (index_type == framework::proto::VarType::INT32) {
gpu_scatter_add_kernel<T, int32_t>(*result, axis, *index, *value,
device_ctx);
} else if (index_type == framework::proto::VarType::INT64) {
gpu_scatter_add_kernel<T, int64_t>(*result, axis, *index, *value,
device_ctx);
}
} else if (reduce_op == "multiply" || reduce_op == "mul") {
if (index_type == framework::proto::VarType::INT32) {
gpu_scatter_mul_kernel<T, int32_t>(*result, axis, *index, *value,
device_ctx);
} else if (index_type == framework::proto::VarType::INT64) {
gpu_scatter_mul_kernel<T, int64_t>(*result, axis, *index, *value,
device_ctx);
}
} else if (reduce_op == "assign") {
if (index_type == framework::proto::VarType::INT32) {
gpu_scatter_assign_kernel<T, int32_t>(*result, axis, *index, *value,
device_ctx);
} else if (index_type == framework::proto::VarType::INT64) {
gpu_scatter_assign_kernel<T, int64_t>(*result, axis, *index, *value,
device_ctx);
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"can not support reduce_op: '%s' for scatter kernel, only "
"support reduce op: 'add‘, 'assign', 'mul' and 'multiply', the "
"defalut reduce op is 'assign' ",
reduce_op));
return;
}
}
};
template <typename T>
class PutAlongAxisGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet(
"PutAlongAxisGradOpCUDAKernel only runs on GPU."));
auto input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
auto value_grad = ctx.Output<Tensor>(framework::GradVarName("Value"));
auto index = ctx.Input<Tensor>("Index");
auto result_grad = ctx.Input<Tensor>(framework::GradVarName("Result"));
auto axis = ctx.Attr<int>("Axis");
const auto &index_type = index->type();
if (input_grad) {
framework::TensorCopy(*result_grad, ctx.GetPlace(), input_grad);
if (index_type == framework::proto::VarType::INT32) {
gpu_scatter_input_grad_kernel<T, int32_t>(
*result_grad, axis, *index, *input_grad, ctx.device_context());
} else {
gpu_scatter_input_grad_kernel<T, int64_t>(
*result_grad, axis, *index, *input_grad, ctx.device_context());
}
}
if (value_grad) {
value_grad->Resize(index->dims());
value_grad->mutable_data<T>(ctx.GetPlace());
if (index_type == framework::proto::VarType::INT32) {
gpu_gather_kernel<T, int32_t>(
*result_grad, axis, *index, *value_grad,
ctx.device_context()); // the gradient of scatter is gather
} else if (index_type == framework::proto::VarType::INT64) {
gpu_gather_kernel<T, int64_t>(*result_grad, axis, *index, *value_grad,
ctx.device_context());
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(put_along_axis, ops::PutAlongAxisCUDAKernel<float>,
ops::PutAlongAxisCUDAKernel<double>,
ops::PutAlongAxisCUDAKernel<int64_t>,
ops::PutAlongAxisCUDAKernel<int>,
ops::PutAlongAxisCUDAKernel<plat::float16>);
REGISTER_OP_CUDA_KERNEL(put_along_axis_grad,
ops::PutAlongAxisGradOpCUDAKernel<float>,
ops::PutAlongAxisGradOpCUDAKernel<double>,
ops::PutAlongAxisGradOpCUDAKernel<int64_t>,
ops::PutAlongAxisGradOpCUDAKernel<int>,
ops::PutAlongAxisGradOpCUDAKernel<plat::float16>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class PutAlongAxisOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet(
"PutAlongAxisOpKernel only runs on CPU."));
auto input = ctx.Input<Tensor>("Input");
auto axis = ctx.Attr<int>("Axis");
auto value = ctx.Input<Tensor>("Value");
auto index = ctx.Input<Tensor>("Index");
auto reduce_op = ctx.Attr<std::string>("Reduce");
auto result = ctx.Output<Tensor>("Result");
framework::TensorCopy(*input, ctx.GetPlace(), result);
const platform::DeviceContext &device_ctx = ctx.device_context();
const auto &index_type = index->type();
if (reduce_op == "add") {
if (index_type == framework::proto::VarType::INT32) {
cpu_scatter_add_kernel<T, int32_t>(*result, axis, *index, *value,
device_ctx);
} else if (index_type == framework::proto::VarType::INT64) {
cpu_scatter_add_kernel<T, int64_t>(*result, axis, *index, *value,
device_ctx);
}
} else if (reduce_op == "multiply" || reduce_op == "mul") {
if (index_type == framework::proto::VarType::INT32) {
cpu_scatter_mul_kernel<T, int32_t>(*result, axis, *index, *value,
device_ctx);
} else if (index_type == framework::proto::VarType::INT64) {
cpu_scatter_mul_kernel<T, int64_t>(*result, axis, *index, *value,
device_ctx);
}
} else if (reduce_op == "assign") {
if (index_type == framework::proto::VarType::INT32) {
cpu_scatter_assign_kernel<T, int32_t>(*result, axis, *index, *value,
device_ctx);
} else if (index_type == framework::proto::VarType::INT64) {
cpu_scatter_assign_kernel<T, int64_t>(*result, axis, *index, *value,
device_ctx);
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"can not support reduce_op: '%s' for scatter kernel, only "
"support reduce op: 'add‘, 'assign', 'mul' and 'multiply', the "
"defalut reduce "
"op is 'assign' ",
reduce_op));
return;
}
}
};
template <typename T>
class PutAlongAxisGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet(
"PutAlongAxisGradOpKernel only runs on CPU."));
auto input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
auto value_grad = ctx.Output<Tensor>(framework::GradVarName("Value"));
auto index = ctx.Input<Tensor>("Index");
auto result_grad = ctx.Input<Tensor>(framework::GradVarName("Result"));
auto axis = ctx.Attr<int>("Axis");
const auto &index_type = index->type();
if (input_grad) {
framework::TensorCopy(*result_grad, ctx.GetPlace(), input_grad);
if (index_type == framework::proto::VarType::INT32) {
cpu_scatter_input_grad_kernel<T, int32_t>(
// Here passing an unused argument *result_grad, because it's
// convenient to instantiate a bunch of template function with the
// same arguments list.
*result_grad, axis, *index, *input_grad, ctx.device_context());
} else {
cpu_scatter_input_grad_kernel<T, int64_t>(
*result_grad, axis, *index, *input_grad, ctx.device_context());
}
}
if (value_grad) {
value_grad->Resize(index->dims());
value_grad->mutable_data<T>(ctx.GetPlace());
if (index_type == framework::proto::VarType::INT32) {
cpu_gather_kernel<T, int32_t>(*result_grad, axis, *index, *value_grad,
ctx.device_context());
} else if (index_type == framework::proto::VarType::INT64) {
cpu_gather_kernel<T, int64_t>(*result_grad, axis, *index, *value_grad,
ctx.device_context());
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -32,7 +32,6 @@ class TakeAlongAxisCUDAKernel : public framework::OpKernel<T> {
auto result = ctx.Output<Tensor>("Result");
result->Resize(index->dims());
result->mutable_data<T>(ctx.GetPlace());
const auto &index_type = index->type();
if (index_type == framework::proto::VarType::INT32) {
gpu_gather_kernel<T, int32_t>(*input, axis, *index, *result,
......
......@@ -159,6 +159,7 @@ from .tensor.manipulation import roll # noqa: F401
from .tensor.manipulation import chunk # noqa: F401
from .tensor.manipulation import tolist # noqa: F401
from .tensor.manipulation import take_along_axis # noqa: F401
from .tensor.manipulation import put_along_axis # noqa: F401
from .tensor.manipulation import tensordot # noqa: F401
from .tensor.manipulation import as_complex # noqa: F401
from .tensor.manipulation import as_real # noqa: F401
......@@ -611,4 +612,6 @@ __all__ = [ # noqa
'repeat_interleave',
'clone',
'renorm',
'take_along_axis',
'put_along_axis',
]
......@@ -139,8 +139,8 @@ def get_numeric_gradient(place,
elif tensor_to_check_dtype == core.VarDesc.VarType.COMPLEX128:
tensor_tp_check_dtype = np.complex128
else:
raise ValueError("Not supported data type " + str(
tensor_to_check_dtype))
raise ValueError("Not supported data type " + str(tensor_to_check_dtype)
+ ", tensor name : " + str(input_to_check))
def get_output():
sum = []
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import copy
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from paddle.framework import core
from paddle.fluid.dygraph.base import switch_to_static_graph
paddle.enable_static()
class TestPutAlongAxisOp(OpTest):
def setUp(self):
self.init_data()
self.reduce_op = "assign"
self.dtype = 'float64'
self.op_type = "put_along_axis"
self.xnp = np.random.random(self.x_shape).astype(self.x_type)
# numpy put_along_axis is an inplace opearion.
self.xnp_result = copy.deepcopy(self.xnp)
np.put_along_axis(self.xnp_result, self.index, self.value, self.axis)
self.target = self.xnp_result
broadcast_shape_list = list(self.x_shape)
broadcast_shape_list[self.axis] = 1
self.braodcast_shape = tuple(broadcast_shape_list)
self.index_broadcast = np.broadcast_to(self.index, self.braodcast_shape)
self.value_broadcast = np.broadcast_to(self.value, self.braodcast_shape)
self.inputs = {
'Input': self.xnp,
'Index': self.index_broadcast,
'Value': self.value_broadcast
}
self.attrs = {'Axis': self.axis, 'Reduce': self.reduce_op}
self.outputs = {'Result': self.target}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["Input", "Value"], "Result")
def init_data(self):
self.x_type = "float64"
self.x_shape = (10, 10, 10)
self.value_type = "float64"
self.value = np.array([99]).astype(self.value_type)
self.index_type = "int32"
self.index = np.array([[[0]]]).astype(self.index_type)
self.axis = 1
self.axis_type = "int64"
class TestPutAlongAxisAPI(unittest.TestCase):
def setUp(self):
np.random.seed(0)
self.shape = [1, 3]
self.index_shape = [1, 1]
self.index_np = np.array([[0]]).astype('int64')
self.x_np = np.random.random(self.shape).astype(np.float32)
self.place = [paddle.CPUPlace()]
self.axis = 0
self.value_np = 99.0
self.value_shape = [1]
self.x_feed = copy.deepcopy(self.x_np)
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))
def test_api_static_case1(self):
paddle.enable_static()
def run(place):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data('X', self.shape)
index = paddle.fluid.data('Index', self.index_shape, "int64")
value = paddle.fluid.data('Value', self.value_shape)
out = paddle.put_along_axis(x, index, value, self.axis)
exe = paddle.static.Executor(self.place[0])
res = exe.run(feed={
'X': self.x_feed,
'Value': self.value_np,
'Index': self.index_np
},
fetch_list=[out])
np.put_along_axis(self.x_np, self.index_np, self.value_np,
self.axis)
# numpy put_along_axis is an inplace opearion.
out_ref = self.x_np
for out in res:
self.assertEqual(np.allclose(out, out_ref, rtol=1e-03), True)
for place in self.place:
run(place)
def test_api_dygraph_case1(self):
def run(place):
paddle.disable_static(place)
x_tensor = paddle.to_tensor(self.x_np)
index_tensor = paddle.to_tensor(self.index_np)
value_tensor = paddle.to_tensor(self.value_np)
out = paddle.put_along_axis(x_tensor, index_tensor, value_tensor,
self.axis)
np.array(
np.put_along_axis(self.x_np, self.index_np, self.value_np,
self.axis))
out_ref = self.x_np
self.assertEqual(
np.allclose(
out.numpy(), out_ref, rtol=1e-03), True)
# for ci coverage, numpy put_along_axis did not support argument of 'reduce'
paddle.put_along_axis(x_tensor, index_tensor, value_tensor,
self.axis, 'mul')
paddle.put_along_axis(x_tensor, index_tensor, value_tensor,
self.axis, 'add')
paddle.enable_static()
for place in self.place:
run(place)
def test_api_dygraph_case2(self):
def run(place):
paddle.disable_static(place)
self.shape = [2, 2]
self.index_shape = [2, 2]
self.index_np = np.array([[0, 0], [1, 0]]).astype('int64')
self.x_np = np.random.random(self.shape).astype(np.float32)
x_tensor = paddle.to_tensor(self.x_np)
index_tensor = paddle.to_tensor(self.index_np)
value_tensor = paddle.to_tensor(self.value_np)
out = paddle.put_along_axis(x_tensor, index_tensor, value_tensor,
self.axis)
np.array(
np.put_along_axis(self.x_np, self.index_np, self.value_np,
self.axis))
out_ref = self.x_np
self.assertEqual(
np.allclose(
out.numpy(), out_ref, rtol=1e-03), True)
paddle.enable_static()
for place in self.place:
run(place)
def test_inplace_dygraph_case3(self):
def run(place):
paddle.disable_static(place)
x_tensor = paddle.to_tensor(self.x_np)
index_tensor = paddle.to_tensor(self.index_np)
value_tensor = paddle.to_tensor(self.value_np)
x_tensor.put_along_axis_(index_tensor, value_tensor, self.axis)
np.array(
np.put_along_axis(self.x_np, self.index_np, self.value_np,
self.axis))
out_ref = self.x_np
self.assertEqual(
np.allclose(
x_tensor.numpy(), out_ref, rtol=1e-03), True)
paddle.enable_static()
for place in self.place:
run(place)
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
......@@ -16,7 +16,7 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest, skip_check_grad_ci
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from paddle.framework import core
......
......@@ -120,6 +120,8 @@ from .manipulation import chunk # noqa: F401
from .manipulation import tensordot # noqa: F401
from .manipulation import as_complex # noqa: F401
from .manipulation import take_along_axis # noqa: F401
from .manipulation import put_along_axis # noqa: F401
from .manipulation import put_along_axis_ # noqa: F401
from .manipulation import as_real # noqa: F401
from .manipulation import moveaxis # noqa: F401
from .manipulation import repeat_interleave # noqa: F401
......@@ -473,6 +475,8 @@ tensor_method_func = [ #noqa
'moveaxis',
'repeat_interleave',
'take_along_axis',
'put_along_axis',
'put_along_axis_',
'exponential_',
]
......
......@@ -2756,9 +2756,9 @@ def take_along_axis(arr, indices, axis):
Take values from the input array by given indices matrix along the designated axis.
Args:
arr (Tensor) : The input Tensor. supported data type are float32 and float64.
arr (Tensor) : The input Tensor. Supported data types are float32 and float64.
indices (Tensor) : Indices to take along each 1d slice of arr. This must match the dimension of arr,
and need to broadcast against arr. Supported data type are int and int64.
and need to broadcast against arr. Supported data type are int and int64.
axis (int) : The axis to take 1d slices along.
Returns:
......@@ -2779,9 +2779,12 @@ def take_along_axis(arr, indices, axis):
print(result)
# [[1, 2, 3]]
"""
broadcast_shape_list = list(arr.shape)
broadcast_shape_list[axis] = 1
broadcast_shape = tuple(broadcast_shape_list)
if (arr.shape == indices.shape):
broadcast_shape = arr.shape
else:
broadcast_shape_list = list(arr.shape)
broadcast_shape_list[axis] = 1
broadcast_shape = tuple(broadcast_shape_list)
if in_dygraph_mode():
indices = paddle.broadcast_to(indices, broadcast_shape)
return _C_ops.take_along_axis(arr, indices, 'Axis', axis)
......@@ -2790,9 +2793,7 @@ def take_along_axis(arr, indices, axis):
'take_along_axis')
check_variable_and_dtype(indices, 'index', ['int32', 'int64'],
'take_along_axis')
indices = paddle.broadcast_to(
indices,
broadcast_shape) # broadcast to shape of the input array first.
indices = paddle.broadcast_to(indices, broadcast_shape)
helper = LayerHelper('take_along_axis', **locals())
dtype = helper.input_dtype()
result = helper.create_variable_for_type_inference(dtype)
......@@ -2803,3 +2804,90 @@ def take_along_axis(arr, indices, axis):
attrs={"Axis": axis},
outputs={"Result": result})
return result
def put_along_axis(arr, indices, values, axis, reduce='assign'):
"""
Put values into the destination array by given indices matrix along the designated axis.
Args:
arr (Tensor) : The Destination Tensor. Supported data types are float32 and float64.
indices (Tensor) : Indices to put along each 1d slice of arr. This must match the dimension of arr,
and need to broadcast against arr. Supported data type are int and int64.
axis (int) : The axis to put 1d slices along.
reduce (string | optinal) : The reduce operation, default is 'assign', support 'add', 'assign', 'mul' and 'multiply'.
Returns :
Tensor: The indexed element, same dtype with arr
Examples:
.. code-block:: python
import paddle
import numpy as np
x_np = np.array([[10, 30, 20], [60, 40, 50]])
index_np = np.array([[0]])
x = paddle.to_tensor(x_np)
index = paddle.to_tensor(index_np)
value = 99
axis = 0
result = paddle.put_along_axis(x, index, value, axis)
print(result)
# [[99, 99, 99],
# [60, 40, 50]]
"""
if (arr.shape == indices.shape):
broadcast_shape = arr.shape
else:
broadcast_shape_list = list(arr.shape)
broadcast_shape_list[axis] = 1
broadcast_shape = tuple(broadcast_shape_list)
if in_dygraph_mode():
indices = paddle.broadcast_to(indices, broadcast_shape)
values = paddle.to_tensor(values) if not isinstance(
values, paddle.Tensor) else values
values = paddle.broadcast_to(values, broadcast_shape)
return _C_ops.put_along_axis(arr, indices, values, "Axis", axis,
"Reduce", reduce)
check_variable_and_dtype(
arr, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'uint8'],
'put_along_axis')
check_variable_and_dtype(indices, 'index', ['int32', 'int64'],
'put_along_axis')
indices = paddle.broadcast_to(indices, broadcast_shape)
values = paddle.broadcast_to(values, broadcast_shape)
helper = LayerHelper('put_along_axis', **locals())
dtype = helper.input_dtype()
result = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type="put_along_axis",
inputs={"Input": arr,
"Index": indices,
"Value": values},
attrs={"Axis": axis,
"Reduce": reduce},
outputs={"Result": result})
return result
@inplace_apis_in_dygraph_only
def put_along_axis_(arr, indices, values, axis, reduce='assign'):
r"""
Inplace version of ``put_along_axis`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_tensor_put_along_axis`.
"""
if (arr.shape == indices.shape):
broadcast_shape = arr.shape
else:
broadcast_shape_list = list(arr.shape)
broadcast_shape_list[axis] = 1
broadcast_shape = tuple(broadcast_shape_list)
indices = paddle.broadcast_to(indices, broadcast_shape)
values = paddle.to_tensor(values) if not isinstance(
values, paddle.Tensor) else values
values = paddle.broadcast_to(values, broadcast_shape)
return _C_ops.put_along_axis_(arr, indices, values, "Axis", axis, "Reduce",
reduce)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册