未验证 提交 1e1a4b9b 编写于 作者: F fuyou765 提交者: GitHub

[MLU] set_value performance optimizing (#44390)

上级 41f11d29
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <numeric>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#include "paddle/fluid/operators/set_value_op.h"
......@@ -62,7 +63,6 @@ class SetValueMLUKernel : public framework::OpKernel<T> {
auto slice_dims_for_assign = decrease_slice_dims;
if (!none_axes.empty()) {
std::vector<int64_t> slice_dims_with_none;
size_t none_axes_cur = 0, decrease_axes_cur = 0;
for (int i = 0; i < slice_dims.size(); ++i) {
while (none_axes_cur < none_axes.size() &&
......@@ -84,51 +84,22 @@ class SetValueMLUKernel : public framework::OpKernel<T> {
slice_dims_for_assign = phi::make_ddim(slice_dims_with_none);
}
auto starts_indices = std::vector<int64_t>(in_dims.size(), 0);
auto ends_indices = std::vector<int64_t>(in_dims.size(), 0);
auto strides_indices = std::vector<int64_t>(in_dims.size(), 0);
int in_size = in_dims.size();
int starts_indices[in_size] = {0};
int ends_indices[in_size] = {0};
int strides_indices[in_size] = {0};
for (int i = 0; i < in_dims.size(); ++i) {
starts_indices[i] = 0;
ends_indices[i] = slice_dims[i];
ends_indices[i] = static_cast<int>(slice_dims[i]);
strides_indices[i] = 1;
}
for (size_t i = 0; i < axes.size(); i++) {
int axis_index = axes[i];
starts_indices[axis_index] = starts[i];
ends_indices[axis_index] = ends[i];
strides_indices[axis_index] = steps[i];
}
int64_t stride_step = phi::product(in_dims);
std::vector<int64_t> index_indices(1, 0);
for (size_t i = 0; i < strides_indices.size(); ++i) {
auto index_size = index_indices.size();
stride_step /= in_dims[i];
for (size_t j = 0; j < index_size; ++j) {
auto start_index = *index_indices.begin();
if (strides_indices[i] > 0) {
for (int64_t k = starts_indices[i]; k < ends_indices[i];
k += strides_indices[i]) {
index_indices.push_back(start_index + k * stride_step);
}
} else {
for (int64_t k = starts_indices[i]; k > ends_indices[i];
k += strides_indices[i]) {
index_indices.push_back(start_index + k * stride_step);
}
}
index_indices.erase(index_indices.begin());
}
starts_indices[axis_index] = static_cast<int>(starts[i]);
ends_indices[axis_index] = static_cast<int>(ends[i]);
strides_indices[axis_index] = static_cast<int>(steps[i]);
}
PADDLE_ENFORCE_EQ(
static_cast<int64_t>(index_indices.size()),
phi::product(slice_dims_for_assign),
platform::errors::InvalidArgument(
"OP(set_value) error index indices and value update not match "));
Tensor value_t(in->type());
if (value_tensor != nullptr) {
value_t.ShareDataWith(*value_tensor);
......@@ -160,29 +131,71 @@ class SetValueMLUKernel : public framework::OpKernel<T> {
int64_t input_numel = phi::product(in_dims);
int64_t value_numel = phi::product(value_temp.dims());
Tensor in_temp, out_temp, val_temp;
Tensor in_temp, out_temp, val_temp, index_out;
int64_t stride_step = phi::product(in_dims);
std::vector<int64_t> index_indices(stride_step);
std::iota(index_indices.begin(), index_indices.end(), 0);
framework::Tensor index_temp;
in_temp.ShareDataWith(*in);
val_temp.ShareDataWith(value_temp);
paddle::framework::TensorFromVector(
index_indices, ctx.device_context(), &index_temp);
index_temp.Resize(in_dims);
auto index_dims = in_dims;
for (int i = 0; i < in_dims.size(); ++i) {
if (starts_indices[i] < 0 || ends_indices[i] < 0) {
starts_indices[i] -= in_dims[i];
ends_indices[i] -= in_dims[i];
}
if (strides_indices[i] > 0)
index_dims[i] =
static_cast<int>((ends_indices[i] - starts_indices[i] - 1) /
strides_indices[i]) +
1;
else
index_dims[i] =
static_cast<int>((ends_indices[i] - starts_indices[i] + 1) /
strides_indices[i]) +
1;
}
auto new_in_dims = phi::make_ddim({input_numel});
auto new_val_dims = phi::make_ddim({value_numel});
in_temp.Resize(new_in_dims);
val_temp.Resize(new_val_dims);
index_out.Resize(index_dims);
index_out.mutable_data<int64_t>(ctx.GetPlace());
cnnlScatterRefMode_t mode = CNNL_SCATTERREF_UPDATE;
MLUCnnlTensorDesc x_desc(in_temp);
MLUCnnlTensorDesc indices_desc(index_temp);
MLUCnnlTensorDesc indices_out_desc(index_out);
MLUCnnlTensorDesc updates_desc(val_temp);
MLUCnnlTensorDesc out_desc(*out);
MLUCnnl::StridedSlice(ctx,
starts_indices,
ends_indices,
strides_indices,
indices_desc.get(),
GetBasePtr(&index_temp),
indices_out_desc.get(),
GetBasePtr(&index_out));
PADDLE_ENFORCE_EQ(
static_cast<int64_t>(phi::product(index_out.dims())),
phi::product(slice_dims_for_assign),
platform::errors::InvalidArgument(
"OP(set_value) error index indices and value update not match "));
Tensor index_final;
index_final.ShareDataWith(index_out);
int64_t indices_numel = phi::product(index_dims);
auto new_index_dims = phi::make_ddim({indices_numel});
index_final.Resize(new_index_dims);
MLUCnnlTensorDesc indices_final_desc(index_final);
MLUCnnl::ScatterRefFunctor(ctx,
x_desc.get(),
GetBasePtr(&in_temp),
updates_desc.get(),
GetBasePtr(&val_temp),
indices_desc.get(),
GetBasePtr(&index_temp),
indices_final_desc.get(),
GetBasePtr(&index_final),
mode);
in_temp.Resize(in_dims);
paddle::framework::TensorCopy(in_temp, ctx.GetPlace(), out);
......
......@@ -127,6 +127,18 @@ class TestSetValueItemSlice4(TestSetValueApi):
self.data[0:, 1:2, :] = self.value
class TestSetValueItemSlice5(TestSetValueApi):
def set_shape(self):
self.shape = [100, 426, 640]
def _call_setitem(self, x):
x[0:-1] = self.value
def _get_answer(self):
self.data[0:-1] = self.value
#TODO: Fix this after MLU support while_loop
#class TestSetValueItemSliceInWhile(TestSetValueApi):
# def _call_setitem(self, x):
......@@ -517,6 +529,7 @@ create_test_value_int32(TestSetValueItemSlice)
create_test_value_int32(TestSetValueItemSlice2)
create_test_value_int32(TestSetValueItemSlice3)
create_test_value_int32(TestSetValueItemSlice4)
create_test_value_int32(TestSetValueItemSlice5)
def create_test_value_tensor_fp32(parent):
......@@ -543,6 +556,7 @@ create_test_value_tensor_fp32(TestSetValueItemSlice)
create_test_value_tensor_fp32(TestSetValueItemSlice2)
create_test_value_tensor_fp32(TestSetValueItemSlice3)
create_test_value_tensor_fp32(TestSetValueItemSlice4)
create_test_value_tensor_fp32(TestSetValueItemSlice5)
# 3. Test different shape of value
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册