未验证 提交 b4b31f47 编写于 作者: C Chengmo 提交者: GitHub

Update index sample (#24109) (#24162)

* update index sample
上级 343464df
...@@ -142,13 +142,14 @@ REGISTER_OPERATOR(index_sample, ops::IndexSampleOp, ops::IndexSampleOpMaker, ...@@ -142,13 +142,14 @@ REGISTER_OPERATOR(index_sample, ops::IndexSampleOp, ops::IndexSampleOpMaker,
REGISTER_OPERATOR(index_sample_grad, ops::IndexSampleGradOp, REGISTER_OPERATOR(index_sample_grad, ops::IndexSampleGradOp,
ops::IndexSampleGradNoNeedBufferVarInferer); ops::IndexSampleGradNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
index_sample, ops::IndexSampleKernel<paddle::platform::CPUPlace, float>, index_sample,
ops::IndexSampleKernel<paddle::platform::CPUPlace, double>, ops::IndexSampleKernel<paddle::platform::CPUDeviceContext, float>,
ops::IndexSampleKernel<paddle::platform::CPUPlace, int>, ops::IndexSampleKernel<paddle::platform::CPUDeviceContext, double>,
ops::IndexSampleKernel<paddle::platform::CPUPlace, int64_t>); ops::IndexSampleKernel<paddle::platform::CPUDeviceContext, int>,
ops::IndexSampleKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
index_sample_grad, index_sample_grad,
ops::IndexSampleGradKernel<paddle::platform::CPUPlace, float>, ops::IndexSampleGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::IndexSampleGradKernel<paddle::platform::CPUPlace, double>, ops::IndexSampleGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::IndexSampleGradKernel<paddle::platform::CPUPlace, int>, ops::IndexSampleGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::IndexSampleGradKernel<paddle::platform::CPUPlace, int64_t>); ops::IndexSampleGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// CUDA kernel registration for the index_sample operator and its gradient.
// The kernel templates themselves are defined in index_sample_op.h.
#include "paddle/fluid/operators/index_sample_op.h"
// Shorthand alias so the registration macros below stay readable.
namespace ops = paddle::operators;
// Register the forward index_sample CUDA kernels, one instantiation per
// supported element type (float, double, int, int64_t).
REGISTER_OP_CUDA_KERNEL(
index_sample,
ops::IndexSampleKernel<paddle::platform::CUDADeviceContext, float>,
ops::IndexSampleKernel<paddle::platform::CUDADeviceContext, double>,
ops::IndexSampleKernel<paddle::platform::CUDADeviceContext, int>,
ops::IndexSampleKernel<paddle::platform::CUDADeviceContext, int64_t>);
// Register the backward (gradient) kernels for the same set of element types,
// mirroring the forward registration above.
REGISTER_OP_CUDA_KERNEL(
index_sample_grad,
ops::IndexSampleGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::IndexSampleGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::IndexSampleGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::IndexSampleGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
...@@ -41,39 +41,41 @@ void IndexSampleInner(const framework::ExecutionContext &context, ...@@ -41,39 +41,41 @@ void IndexSampleInner(const framework::ExecutionContext &context,
auto value_length = input_dims[1]; auto value_length = input_dims[1];
auto index_length = index_dims[1]; auto index_length = index_dims[1];
int index_ids_num = index.numel(); int index_ids_num = index.numel();
auto *input_data = input.data<T>();
auto *index_data = index.data<IndexT>();
std::vector<T> res{}; std::vector<T> input_vec;
std::vector<IndexT> index_vec;
TensorToVector(input, context.device_context(), &input_vec);
TensorToVector(index, context.device_context(), &index_vec);
std::vector<T> res(index_ids_num);
for (int i = 0; i < index_ids_num; i++) { for (int i = 0; i < index_ids_num; i++) {
int b = floor(i / index_length); int b = floor(i / index_length);
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
index_data[i], 0, index_vec[i], 0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Variable value (index) of OP(index_sample) " "Variable value (index) of OP(index_sample) "
"expected >= 0 and < %ld, but got %ld. Please check input " "expected >= 0 and < %ld, but got %ld. Please check input "
"value.", "value.",
value_length, index_data[i])); value_length, index_vec[i]));
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
index_data[i], value_length, index_vec[i], value_length,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Variable value (index) of OP(index_sample) " "Variable value (index) of OP(index_sample) "
"expected >= 0 and < %ld, but got %ld. Please check input " "expected >= 0 and < %ld, but got %ld. Please check input "
"value.", "value.",
value_length, index_data[i])); value_length, index_vec[i]));
int v_i = b * value_length + static_cast<int>(index_data[i]); int v_i = b * value_length + static_cast<int>(index_vec[i]);
T v = input_data[v_i]; T v = input_vec[v_i];
VLOG(4) << "Index Sample: batch = " << b << " index = " << v_i VLOG(4) << "Index Sample: batch = " << b << " index = " << v_i
<< " value = " << v; << " value = " << v;
res.push_back(v); res[i] = v;
} }
auto ddim = framework::make_ddim({batch_size, index_length}); auto ddim = framework::make_ddim({batch_size, index_length});
output->mutable_data<T>(context.GetPlace());
framework::TensorFromVector(res, context.device_context(), output);
output->Resize(ddim); output->Resize(ddim);
T *out_data = output->mutable_data<T>(context.GetPlace());
memcpy(out_data, &res[0], sizeof(T) * index_ids_num);
} }
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
...@@ -113,39 +115,42 @@ template <typename T, typename IndexT = int> ...@@ -113,39 +115,42 @@ template <typename T, typename IndexT = int>
void IndexSampleGradInner(const framework::ExecutionContext &context, void IndexSampleGradInner(const framework::ExecutionContext &context,
const LoDTensor &out_grad, const LoDTensor &index, const LoDTensor &out_grad, const LoDTensor &index,
LoDTensor *x_grad) { LoDTensor *x_grad) {
std::vector<T> out_grad_vec;
std::vector<IndexT> index_vec;
TensorToVector(out_grad, context.device_context(), &out_grad_vec);
TensorToVector(index, context.device_context(), &index_vec);
auto index_dims = index.dims(); auto index_dims = index.dims();
auto x_grad_dims = x_grad->dims(); auto x_grad_dims = x_grad->dims();
int batch_size = x_grad_dims[0];
auto value_length = x_grad_dims[1]; auto value_length = x_grad_dims[1];
auto index_length = index_dims[1]; auto index_length = index_dims[1];
int index_ids_num = index.numel(); int index_ids_num = index.numel();
T *x_grad_data = x_grad->mutable_data<T>(context.GetPlace()); std::vector<T> x_grad_vec(x_grad->numel(), 0);
auto *out_grad_data = out_grad.data<T>();
auto *index_data = index.data<IndexT>();
memset(x_grad_data, 0, batch_size * value_length * sizeof(T));
for (int i = 0; i < index_ids_num; i++) { for (int i = 0; i < index_ids_num; i++) {
int b = floor(i / index_length); int b = floor(i / index_length);
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
index_data[i], 0, index_vec[i], 0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Variable value (index) of OP(index_sample_grad) " "Variable value (index) of OP(index_sample_grad) "
"expected >= 0 and < %ld, but got %ld. Please check input " "expected >= 0 and < %ld, but got %ld. Please check input "
"value.", "value.",
value_length, index_data[i])); value_length, index_vec[i]));
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
index_data[i], value_length, index_vec[i], value_length,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Variable value (index) of OP(index_sample_grad) " "Variable value (index) of OP(index_sample_grad) "
"expected >= 0 and < %ld, but got %ld. Please check input " "expected >= 0 and < %ld, but got %ld. Please check input "
"value.", "value.",
value_length, index_data[i])); value_length, index_vec[i]));
int v_i = b * value_length + static_cast<int>(index_data[i]); int v_i = b * value_length + static_cast<int>(index_vec[i]);
x_grad_data[v_i] += out_grad_data[i]; x_grad_vec[v_i] += out_grad_vec[i];
} }
x_grad->mutable_data<T>(context.GetPlace());
framework::TensorFromVector(x_grad_vec, context.device_context(), x_grad);
x_grad->Resize(x_grad_dims);
} }
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
......
...@@ -32,6 +32,7 @@ class TestIndexSampleOp(OpTest): ...@@ -32,6 +32,7 @@ class TestIndexSampleOp(OpTest):
for i in range(self.index_shape[0]): for i in range(self.index_shape[0]):
for j in indexnp[i]: for j in indexnp[i]:
index_array.append(xnp[i, j]) index_array.append(xnp[i, j])
index_array = np.array(index_array).astype(self.x_type)
out = np.reshape(index_array, self.index_shape) out = np.reshape(index_array, self.index_shape)
self.outputs = {'Out': out} self.outputs = {'Out': out}
......
...@@ -475,21 +475,48 @@ def index_sample(x, index): ...@@ -475,21 +475,48 @@ def index_sample(x, index):
import paddle.fluid as fluid import paddle.fluid as fluid
import numpy as np import numpy as np
# create x value data = np.array([[1.0, 2.0, 3.0, 4.0],
x_shape = (2, 5) [5.0, 6.0, 7.0, 8.0],
x_type = "float64" [9.0, 10.0, 11.0, 12.0]]).astype('float32')
x_np = np.random.random(x_shape).astype(x_type)
data_index = np.array([[0, 1, 2],
# create index value [1, 2, 3],
index_shape = (2, 3) [0, 0, 0]]).astype('int32')
index_type = "int32"
index_np = np.random.randint(low=0, target_data = np.array([[100, 200, 300, 400],
high=x_shape[1], [500, 600, 700, 800],
size=index_shape).astype(index_type) [900, 1000, 1100, 1200]]).astype('int32')
x = fluid.data(name='x', shape=[-1, 5], dtype='float64') with fluid.dygraph.guard():
index = fluid.data(name='index', shape=[-1, 3], dtype='int32') x = fluid.dygraph.to_variable(data)
output = paddle.index_sample(x=x, index=index) index = fluid.dygraph.to_variable(data_index)
target = fluid.dygraph.to_variable(target_data)
out_z1 = paddle.index_sample(x, index)
print(out_z1.numpy())
#[[1. 2. 3.]
# [6. 7. 8.]
# [9. 9. 9.]]
# Use the index of the maximum value by topk op
# get the value of the element of the corresponding index in other tensors
top_value, top_index = fluid.layers.topk(x, k=2)
out_z2 = paddle.index_sample(target, top_index)
print(top_value.numpy())
#[[ 4. 3.]
# [ 8. 7.]
# [12. 11.]]
print(top_index.numpy())
#[[3 2]
# [3 2]
# [3 2]]
print(out_z2.numpy())
#[[ 400 300]
# [ 800 700]
# [1200 1100]]
""" """
helper = LayerHelper("index_sample", **locals()) helper = LayerHelper("index_sample", **locals())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册