未验证 提交 1e1a4b9b 编写于 作者: F fuyou765 提交者: GitHub

[MLU] set_value performance optimizing (#44390)

上级 41f11d29
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <numeric>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h" #include "paddle/fluid/operators/mlu/mlu_baseop.h"
#include "paddle/fluid/operators/set_value_op.h" #include "paddle/fluid/operators/set_value_op.h"
...@@ -62,7 +63,6 @@ class SetValueMLUKernel : public framework::OpKernel<T> { ...@@ -62,7 +63,6 @@ class SetValueMLUKernel : public framework::OpKernel<T> {
auto slice_dims_for_assign = decrease_slice_dims; auto slice_dims_for_assign = decrease_slice_dims;
if (!none_axes.empty()) { if (!none_axes.empty()) {
std::vector<int64_t> slice_dims_with_none; std::vector<int64_t> slice_dims_with_none;
size_t none_axes_cur = 0, decrease_axes_cur = 0; size_t none_axes_cur = 0, decrease_axes_cur = 0;
for (int i = 0; i < slice_dims.size(); ++i) { for (int i = 0; i < slice_dims.size(); ++i) {
while (none_axes_cur < none_axes.size() && while (none_axes_cur < none_axes.size() &&
...@@ -84,51 +84,22 @@ class SetValueMLUKernel : public framework::OpKernel<T> { ...@@ -84,51 +84,22 @@ class SetValueMLUKernel : public framework::OpKernel<T> {
slice_dims_for_assign = phi::make_ddim(slice_dims_with_none); slice_dims_for_assign = phi::make_ddim(slice_dims_with_none);
} }
int in_size = in_dims.size();
auto starts_indices = std::vector<int64_t>(in_dims.size(), 0); int starts_indices[in_size] = {0};
auto ends_indices = std::vector<int64_t>(in_dims.size(), 0); int ends_indices[in_size] = {0};
auto strides_indices = std::vector<int64_t>(in_dims.size(), 0); int strides_indices[in_size] = {0};
for (int i = 0; i < in_dims.size(); ++i) { for (int i = 0; i < in_dims.size(); ++i) {
starts_indices[i] = 0; starts_indices[i] = 0;
ends_indices[i] = slice_dims[i]; ends_indices[i] = static_cast<int>(slice_dims[i]);
strides_indices[i] = 1; strides_indices[i] = 1;
} }
for (size_t i = 0; i < axes.size(); i++) { for (size_t i = 0; i < axes.size(); i++) {
int axis_index = axes[i]; int axis_index = axes[i];
starts_indices[axis_index] = starts[i]; starts_indices[axis_index] = static_cast<int>(starts[i]);
ends_indices[axis_index] = ends[i]; ends_indices[axis_index] = static_cast<int>(ends[i]);
strides_indices[axis_index] = steps[i]; strides_indices[axis_index] = static_cast<int>(steps[i]);
}
int64_t stride_step = phi::product(in_dims);
std::vector<int64_t> index_indices(1, 0);
for (size_t i = 0; i < strides_indices.size(); ++i) {
auto index_size = index_indices.size();
stride_step /= in_dims[i];
for (size_t j = 0; j < index_size; ++j) {
auto start_index = *index_indices.begin();
if (strides_indices[i] > 0) {
for (int64_t k = starts_indices[i]; k < ends_indices[i];
k += strides_indices[i]) {
index_indices.push_back(start_index + k * stride_step);
}
} else {
for (int64_t k = starts_indices[i]; k > ends_indices[i];
k += strides_indices[i]) {
index_indices.push_back(start_index + k * stride_step);
}
}
index_indices.erase(index_indices.begin());
}
} }
PADDLE_ENFORCE_EQ(
static_cast<int64_t>(index_indices.size()),
phi::product(slice_dims_for_assign),
platform::errors::InvalidArgument(
"OP(set_value) error index indices and value update not match "));
Tensor value_t(in->type()); Tensor value_t(in->type());
if (value_tensor != nullptr) { if (value_tensor != nullptr) {
value_t.ShareDataWith(*value_tensor); value_t.ShareDataWith(*value_tensor);
...@@ -160,29 +131,71 @@ class SetValueMLUKernel : public framework::OpKernel<T> { ...@@ -160,29 +131,71 @@ class SetValueMLUKernel : public framework::OpKernel<T> {
int64_t input_numel = phi::product(in_dims); int64_t input_numel = phi::product(in_dims);
int64_t value_numel = phi::product(value_temp.dims()); int64_t value_numel = phi::product(value_temp.dims());
Tensor in_temp, out_temp, val_temp; Tensor in_temp, out_temp, val_temp, index_out;
int64_t stride_step = phi::product(in_dims);
std::vector<int64_t> index_indices(stride_step);
std::iota(index_indices.begin(), index_indices.end(), 0);
framework::Tensor index_temp; framework::Tensor index_temp;
in_temp.ShareDataWith(*in); in_temp.ShareDataWith(*in);
val_temp.ShareDataWith(value_temp); val_temp.ShareDataWith(value_temp);
paddle::framework::TensorFromVector( paddle::framework::TensorFromVector(
index_indices, ctx.device_context(), &index_temp); index_indices, ctx.device_context(), &index_temp);
index_temp.Resize(in_dims);
auto index_dims = in_dims;
for (int i = 0; i < in_dims.size(); ++i) {
if (starts_indices[i] < 0 || ends_indices[i] < 0) {
starts_indices[i] -= in_dims[i];
ends_indices[i] -= in_dims[i];
}
if (strides_indices[i] > 0)
index_dims[i] =
static_cast<int>((ends_indices[i] - starts_indices[i] - 1) /
strides_indices[i]) +
1;
else
index_dims[i] =
static_cast<int>((ends_indices[i] - starts_indices[i] + 1) /
strides_indices[i]) +
1;
}
auto new_in_dims = phi::make_ddim({input_numel}); auto new_in_dims = phi::make_ddim({input_numel});
auto new_val_dims = phi::make_ddim({value_numel}); auto new_val_dims = phi::make_ddim({value_numel});
in_temp.Resize(new_in_dims); in_temp.Resize(new_in_dims);
val_temp.Resize(new_val_dims); val_temp.Resize(new_val_dims);
index_out.Resize(index_dims);
index_out.mutable_data<int64_t>(ctx.GetPlace());
cnnlScatterRefMode_t mode = CNNL_SCATTERREF_UPDATE; cnnlScatterRefMode_t mode = CNNL_SCATTERREF_UPDATE;
MLUCnnlTensorDesc x_desc(in_temp); MLUCnnlTensorDesc x_desc(in_temp);
MLUCnnlTensorDesc indices_desc(index_temp); MLUCnnlTensorDesc indices_desc(index_temp);
MLUCnnlTensorDesc indices_out_desc(index_out);
MLUCnnlTensorDesc updates_desc(val_temp); MLUCnnlTensorDesc updates_desc(val_temp);
MLUCnnlTensorDesc out_desc(*out); MLUCnnlTensorDesc out_desc(*out);
MLUCnnl::StridedSlice(ctx,
starts_indices,
ends_indices,
strides_indices,
indices_desc.get(),
GetBasePtr(&index_temp),
indices_out_desc.get(),
GetBasePtr(&index_out));
PADDLE_ENFORCE_EQ(
static_cast<int64_t>(phi::product(index_out.dims())),
phi::product(slice_dims_for_assign),
platform::errors::InvalidArgument(
"OP(set_value) error index indices and value update not match "));
Tensor index_final;
index_final.ShareDataWith(index_out);
int64_t indices_numel = phi::product(index_dims);
auto new_index_dims = phi::make_ddim({indices_numel});
index_final.Resize(new_index_dims);
MLUCnnlTensorDesc indices_final_desc(index_final);
MLUCnnl::ScatterRefFunctor(ctx, MLUCnnl::ScatterRefFunctor(ctx,
x_desc.get(), x_desc.get(),
GetBasePtr(&in_temp), GetBasePtr(&in_temp),
updates_desc.get(), updates_desc.get(),
GetBasePtr(&val_temp), GetBasePtr(&val_temp),
indices_desc.get(), indices_final_desc.get(),
GetBasePtr(&index_temp), GetBasePtr(&index_final),
mode); mode);
in_temp.Resize(in_dims); in_temp.Resize(in_dims);
paddle::framework::TensorCopy(in_temp, ctx.GetPlace(), out); paddle::framework::TensorCopy(in_temp, ctx.GetPlace(), out);
......
...@@ -127,6 +127,18 @@ class TestSetValueItemSlice4(TestSetValueApi): ...@@ -127,6 +127,18 @@ class TestSetValueItemSlice4(TestSetValueApi):
self.data[0:, 1:2, :] = self.value self.data[0:, 1:2, :] = self.value
# Test case added with this change: assigns self.value through the slice 0:-1
# (negative end index on the first axis) on a tensor of shape [100, 426, 640].
# NOTE(review): set_shape/_call_setitem/_get_answer look like hooks overridden
# from TestSetValueApi, which is not visible here -- confirm against the base class.
class TestSetValueItemSlice5(TestSetValueApi):
# Sets the shape used by this case to [100, 426, 640].
def set_shape(self):
self.shape = [100, 426, 640]
# Applies the slice assignment x[0:-1] = self.value to the argument x.
def _call_setitem(self, x):
x[0:-1] = self.value
# Applies the same slice assignment to self.data
# (presumably the reference result the output is compared with -- TODO confirm).
def _get_answer(self):
self.data[0:-1] = self.value
#TODO: Fix this after MLU support while_loop #TODO: Fix this after MLU support while_loop
#class TestSetValueItemSliceInWhile(TestSetValueApi): #class TestSetValueItemSliceInWhile(TestSetValueApi):
# def _call_setitem(self, x): # def _call_setitem(self, x):
...@@ -517,6 +529,7 @@ create_test_value_int32(TestSetValueItemSlice) ...@@ -517,6 +529,7 @@ create_test_value_int32(TestSetValueItemSlice)
create_test_value_int32(TestSetValueItemSlice2) create_test_value_int32(TestSetValueItemSlice2)
create_test_value_int32(TestSetValueItemSlice3) create_test_value_int32(TestSetValueItemSlice3)
create_test_value_int32(TestSetValueItemSlice4) create_test_value_int32(TestSetValueItemSlice4)
create_test_value_int32(TestSetValueItemSlice5)
def create_test_value_tensor_fp32(parent): def create_test_value_tensor_fp32(parent):
...@@ -543,6 +556,7 @@ create_test_value_tensor_fp32(TestSetValueItemSlice) ...@@ -543,6 +556,7 @@ create_test_value_tensor_fp32(TestSetValueItemSlice)
create_test_value_tensor_fp32(TestSetValueItemSlice2) create_test_value_tensor_fp32(TestSetValueItemSlice2)
create_test_value_tensor_fp32(TestSetValueItemSlice3) create_test_value_tensor_fp32(TestSetValueItemSlice3)
create_test_value_tensor_fp32(TestSetValueItemSlice4) create_test_value_tensor_fp32(TestSetValueItemSlice4)
create_test_value_tensor_fp32(TestSetValueItemSlice5)
# 3. Test different shape of value # 3. Test different shape of value
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册