diff --git a/paddle/fluid/operators/set_value_op_mlu.cc b/paddle/fluid/operators/set_value_op_mlu.cc index 44422994f60daa38f1b0cf5877649a23ce7a132f..9a6277dfa2312ba9ba83f16e433f258775446d26 100644 --- a/paddle/fluid/operators/set_value_op_mlu.cc +++ b/paddle/fluid/operators/set_value_op_mlu.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/mlu/mlu_baseop.h" #include "paddle/fluid/operators/set_value_op.h" @@ -62,7 +63,6 @@ class SetValueMLUKernel : public framework::OpKernel { auto slice_dims_for_assign = decrease_slice_dims; if (!none_axes.empty()) { std::vector slice_dims_with_none; - size_t none_axes_cur = 0, decrease_axes_cur = 0; for (int i = 0; i < slice_dims.size(); ++i) { while (none_axes_cur < none_axes.size() && @@ -84,51 +84,22 @@ class SetValueMLUKernel : public framework::OpKernel { slice_dims_for_assign = phi::make_ddim(slice_dims_with_none); } - - auto starts_indices = std::vector(in_dims.size(), 0); - auto ends_indices = std::vector(in_dims.size(), 0); - auto strides_indices = std::vector(in_dims.size(), 0); + int in_size = in_dims.size(); + int starts_indices[in_size] = {0}; + int ends_indices[in_size] = {0}; + int strides_indices[in_size] = {0}; for (int i = 0; i < in_dims.size(); ++i) { starts_indices[i] = 0; - ends_indices[i] = slice_dims[i]; + ends_indices[i] = static_cast(slice_dims[i]); strides_indices[i] = 1; } for (size_t i = 0; i < axes.size(); i++) { int axis_index = axes[i]; - starts_indices[axis_index] = starts[i]; - ends_indices[axis_index] = ends[i]; - strides_indices[axis_index] = steps[i]; - } - - int64_t stride_step = phi::product(in_dims); - std::vector index_indices(1, 0); - for (size_t i = 0; i < strides_indices.size(); ++i) { - auto index_size = index_indices.size(); - stride_step /= in_dims[i]; - for (size_t j = 0; j < index_size; ++j) { - auto start_index = *index_indices.begin(); - if (strides_indices[i] > 0) { - for (int64_t k = starts_indices[i]; k < ends_indices[i]; - k += strides_indices[i]) { - index_indices.push_back(start_index + k * stride_step); - } - } else { - for (int64_t k = starts_indices[i]; k > ends_indices[i]; - k += strides_indices[i]) { - index_indices.push_back(start_index + k * stride_step); - } - } - index_indices.erase(index_indices.begin()); - } + starts_indices[axis_index] = static_cast(starts[i]); + ends_indices[axis_index] = static_cast(ends[i]); + strides_indices[axis_index] = static_cast(steps[i]); } - - PADDLE_ENFORCE_EQ( - static_cast(index_indices.size()), - phi::product(slice_dims_for_assign), - platform::errors::InvalidArgument( - "OP(set_value) error index indices and value update not match ")); - Tensor value_t(in->type()); if (value_tensor != nullptr) { value_t.ShareDataWith(*value_tensor); @@ -160,29 +131,71 @@ class SetValueMLUKernel : public framework::OpKernel { int64_t input_numel = phi::product(in_dims); int64_t value_numel = phi::product(value_temp.dims()); - Tensor in_temp, out_temp, val_temp; + Tensor in_temp, out_temp, val_temp, index_out; + int64_t stride_step = phi::product(in_dims); + std::vector index_indices(stride_step); + std::iota(index_indices.begin(), index_indices.end(), 0); framework::Tensor index_temp; in_temp.ShareDataWith(*in); val_temp.ShareDataWith(value_temp); paddle::framework::TensorFromVector( index_indices, ctx.device_context(), &index_temp); + index_temp.Resize(in_dims); + auto index_dims = in_dims; + for (int i = 0; i < in_dims.size(); ++i) { + if (starts_indices[i] < 0 || ends_indices[i] < 0) { + starts_indices[i] -= in_dims[i]; + ends_indices[i] -= in_dims[i]; + } + if (strides_indices[i] > 0) + index_dims[i] = + static_cast((ends_indices[i] - starts_indices[i] - 1) / + strides_indices[i]) + + 1; + else + index_dims[i] = + static_cast((ends_indices[i] - starts_indices[i] + 1) / + strides_indices[i]) + + 1; + } auto new_in_dims = phi::make_ddim({input_numel}); auto new_val_dims = phi::make_ddim({value_numel}); in_temp.Resize(new_in_dims); val_temp.Resize(new_val_dims); + index_out.Resize(index_dims); + index_out.mutable_data(ctx.GetPlace()); cnnlScatterRefMode_t mode = CNNL_SCATTERREF_UPDATE; MLUCnnlTensorDesc x_desc(in_temp); MLUCnnlTensorDesc indices_desc(index_temp); + MLUCnnlTensorDesc indices_out_desc(index_out); MLUCnnlTensorDesc updates_desc(val_temp); MLUCnnlTensorDesc out_desc(*out); - + MLUCnnl::StridedSlice(ctx, + starts_indices, + ends_indices, + strides_indices, + indices_desc.get(), + GetBasePtr(&index_temp), + indices_out_desc.get(), + GetBasePtr(&index_out)); + PADDLE_ENFORCE_EQ( + static_cast(phi::product(index_out.dims())), + phi::product(slice_dims_for_assign), + platform::errors::InvalidArgument( + "OP(set_value) error index indices and value update not match ")); + Tensor index_final; + index_final.ShareDataWith(index_out); + int64_t indices_numel = phi::product(index_dims); + auto new_index_dims = phi::make_ddim({indices_numel}); + index_final.Resize(new_index_dims); + MLUCnnlTensorDesc indices_final_desc(index_final); MLUCnnl::ScatterRefFunctor(ctx, x_desc.get(), GetBasePtr(&in_temp), updates_desc.get(), GetBasePtr(&val_temp), - indices_desc.get(), - GetBasePtr(&index_temp), + indices_final_desc.get(), + GetBasePtr(&index_final), mode); in_temp.Resize(in_dims); paddle::framework::TensorCopy(in_temp, ctx.GetPlace(), out); diff --git a/python/paddle/fluid/tests/unittests/mlu/test_set_value_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_set_value_op_mlu.py index f6183687f6a47c74b314df69ec902812516af734..1842f9a2f632c2db4fb3e75d5dc3498f7793586c 100644 --- a/python/paddle/fluid/tests/unittests/mlu/test_set_value_op_mlu.py +++ b/python/paddle/fluid/tests/unittests/mlu/test_set_value_op_mlu.py @@ -127,6 +127,18 @@ class TestSetValueItemSlice4(TestSetValueApi): self.data[0:, 1:2, :] = self.value +class TestSetValueItemSlice5(TestSetValueApi): + + def set_shape(self): + self.shape = [100, 426, 640] + + def _call_setitem(self, x): + x[0:-1] = self.value + + def _get_answer(self): + self.data[0:-1] = self.value + + #TODO: Fix this after MLU support while_loop #class TestSetValueItemSliceInWhile(TestSetValueApi): # def _call_setitem(self, x): @@ -517,6 +529,7 @@ create_test_value_int32(TestSetValueItemSlice) create_test_value_int32(TestSetValueItemSlice2) create_test_value_int32(TestSetValueItemSlice3) create_test_value_int32(TestSetValueItemSlice4) +create_test_value_int32(TestSetValueItemSlice5) def create_test_value_tensor_fp32(parent): @@ -543,6 +556,7 @@ create_test_value_tensor_fp32(TestSetValueItemSlice) create_test_value_tensor_fp32(TestSetValueItemSlice2) create_test_value_tensor_fp32(TestSetValueItemSlice3) create_test_value_tensor_fp32(TestSetValueItemSlice4) +create_test_value_tensor_fp32(TestSetValueItemSlice5) # 3. Test different shape of value