未验证 提交 239dbc4e 编写于 作者: J JYChen 提交者: GitHub

fix the set_value error in cpu (#49804)

* fix the set_value error in cpu

* add a unitest for set_value OP

* fix platform::is_gpu_place

* add todo note for set_value
上级 aba6af4f
......@@ -136,7 +136,6 @@ void SetValueImpl(const Context& dev_ctx,
Empty<T>(dev_ctx, IntArray{slice_dims.Get(), slice_dims.size()});
DenseTensor pad_tensor =
Empty<T>(dev_ctx, IntArray{in_dims.Get(), in_dims.size()});
auto pad_e = EigenTensor<T, RANK>::From(pad_tensor, in_dims);
auto out_e = EigenTensor<T, RANK>::From(*out);
auto slice_e = EigenTensor<T, RANK>::From(slice_tensor, slice_dims);
......@@ -185,16 +184,38 @@ void SetValueImpl(const Context& dev_ctx,
// is [3], which is right.
slice_tensor.Resize(slice_dims_for_assign);
CheckIsDimsMatch(slice_dims_for_assign, value.dims());
// ElementwiseComputeEx can do broadcasting
funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T>(
dev_ctx,
slice_tensor,
value,
-1,
funcs::SubtractFunctor<T>(),
&slice_tensor);
bool is_gpu_place = dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU;
if (is_gpu_place || slice_tensor.dims().size() >= value.dims().size()) {
// [Why here we confirm running device]
// ElementwiseComputeEx can do broadcasting in two cases:
// 1. The place is GPU.
// 2. The place is CPU, and the 'x' does not need broadcast.
// Please see the note in
// paddle/fluid/operators/elementwise/elementwise_op_function.h
// So, here we choose different logic depending on the device to avoid
// numerical problems, temporarily.
//
// TODO(zoooo0820): Reimplement logic of set_value to avoid using
// elementwise-sub.
funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T>(
dev_ctx,
slice_tensor,
value,
-1,
funcs::SubtractFunctor<T>(),
&slice_tensor);
} else {
funcs::ElementwiseCompute<funcs::InverseSubtractFunctor<T>, T>(
dev_ctx,
slice_tensor,
value,
-1,
funcs::InverseSubtractFunctor<T>(),
&slice_tensor);
}
slice_tensor.Resize(slice_dims);
// - Step 2.2 Pad slice tensor with 0
......
......@@ -982,6 +982,45 @@ class TestSetValueValueShape5(TestSetValueApi):
self.data[:, 0] = self.value
# This is to test case which dims of indexed Tensor is
# less than value Tensor on CPU / GPU.
class TestSetValueValueShape6(TestSetValueApi):
def set_value(self):
self.value = np.ones((1, 4)) * 5
def set_shape(self):
self.shape = [4, 4]
def _call_setitem(self, x):
x[:, 0] = self.value # x is Paddle.Tensor
def _get_answer(self):
self.data[:, 0] = self.value
def test_api(self):
places = ['cpu']
if paddle.is_compiled_with_cuda():
places.append('gpu')
for place in places:
paddle.set_device(place)
static_out = self._run_static()
dynamic_out = self._run_dynamic()
self._get_answer()
error_msg = (
"\nIn {} mode: \nExpected res = \n{}, \n\nbut received : \n{}"
)
self.assertTrue(
(self.data == static_out).all(),
msg=error_msg.format("static", self.data, static_out),
)
self.assertTrue(
(self.data == dynamic_out).all(),
msg=error_msg.format("dynamic", self.data, dynamic_out),
)
# 4. Test error
class TestError(TestSetValueBase):
def _value_type_error(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册