Unverified · Commit e6def1eb authored by zhouweiwei2014, committed by GitHub

[Zero-Dim] Support all/any/min/prod/logsumexp/amax/amin/some loss output 0D,test=allcase (#53051)

Parent 7fa415ca
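For orientation, a minimal sketch (not part of the diff) of the user-visible behavior this commit targets, assuming a Paddle build that includes it: full reductions with the listed APIs now produce 0-D tensors of shape [] rather than shape [1].

import paddle

x = paddle.rand([3, 5])
# Reductions over all elements now infer an empty shape, i.e. a 0-D tensor.
for api in (paddle.min, paddle.prod, paddle.logsumexp, paddle.amax, paddle.amin):
    assert api(x).shape == []           # previously [1]
# Boolean reductions follow the same rule.
assert paddle.all(x > -1.0).shape == []
assert paddle.any(x > 0.5).shape == []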
......@@ -58,8 +58,8 @@ class AssertOp : public framework::OperatorBase {
"Input(Condition) of AssertOp is not found."));
const phi::DenseTensor &cond = cond_var_ptr->Get<phi::DenseTensor>();
PADDLE_ENFORCE_EQ(
cond.dims(),
phi::make_ddim({1}),
cond.numel(),
1,
platform::errors::InvalidArgument(
"The numel of Input(Condition) of AssertOp must be 1. But now "
"the Condition's shape is %s.",
......
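The check above is relaxed from an exact [1] shape to numel == 1, so a 0-D boolean condition now passes AssertOp as well. A small illustration of the shapes that satisfy the new condition (hypothetical usage, not the op's own tests):

import paddle

cond_0d = paddle.full([], True, dtype='bool')   # shape [], one element
cond_1d = paddle.full([1], True, dtype='bool')  # shape [1], one element
assert int(cond_0d.numel()) == 1
assert int(cond_1d.numel()) == 1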
......@@ -54,10 +54,9 @@ class ReduceMaxCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
} // namespace operators
} // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(
reduce_max,
ReduceMaxInferShapeFunctor,
PD_INFER_META(phi::ReduceIntArrayAxisInferMetaBase));
DECLARE_INFER_SHAPE_FUNCTOR(reduce_max,
ReduceMaxInferShapeFunctor,
PD_INFER_META(phi::OriginReduceInferMetaBase));
REGISTER_OPERATOR(
reduce_max,
......
......@@ -98,10 +98,9 @@ class __reduce_meanMaker__ : public ops::ReduceBaseOpMaker {
virtual std::string GetOpType() const { return "Reduce reduce_mean"; }
};
DECLARE_INFER_SHAPE_FUNCTOR(
reduce_mean,
ReduceMeanInferShapeFunctor,
PD_INFER_META(phi::ReduceIntArrayAxisInferMetaBase));
DECLARE_INFER_SHAPE_FUNCTOR(reduce_mean,
ReduceMeanInferShapeFunctor,
PD_INFER_META(phi::OriginReduceInferMetaBase));
REGISTER_OPERATOR(reduce_mean,
ops::ReduceBaseOp,
......
......@@ -1282,7 +1282,7 @@ void prod_grad(const Tensor& x,
if (!keep_dim) {
auto axis_ = std::vector<int64_t>();
if (reduce_all) {
for (int64_t i = 1; i < x_dim_size; i++) {
for (int64_t i = 0; i < x_dim_size; i++) {
axis_.push_back(i);
}
} else {
......
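The corrected loop bound makes axis_ cover every axis (0 through x_dim_size - 1) in the reduce_all branch; with the old bound, axis 0 was left out when rebuilding the reduced dims for the gradient. A hedged end-to-end check of the intended behavior:

import paddle

x = paddle.rand([2, 3, 4])
x.stop_gradient = False
out = paddle.prod(x)        # full reduction -> 0-D output
out.backward()
assert out.shape == []
assert x.grad.shape == [2, 3, 4]   # the gradient restores every input axis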
......@@ -775,7 +775,7 @@
args : (Tensor x, IntArray axis={}, bool keepdim=false)
output : Tensor(out)
infer_meta :
func : ReduceIntArrayAxisInferMeta
func : OriginReduceInferMeta
kernel :
func : max
backward : max_grad
......@@ -811,7 +811,7 @@
args : (Tensor x, IntArray axis={}, bool keepdim=false)
output : Tensor(out)
infer_meta :
func : ReduceIntArrayAxisInferMeta
func : OriginReduceInferMeta
kernel :
func : mean
backward : mean_grad
......
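Note the direction of this routing: max and mean are moved onto the newly added OriginReduceInferMeta, which keeps the legacy shape-[1] result for a full reduce, while the ops named in the commit title use the 0-D-aware ReduceIntArrayAxisInferMeta. This matches the Python tests further down, which skip sum/mean/nanmean/nansum/max in the N-D full-reduce case. A hedged illustration of the split at this commit:

import paddle

x = paddle.rand([3, 5])
assert paddle.min(x).shape == []   # min is migrated to the 0-D rule here
# max/mean remain on the legacy OriginReduceInferMeta path for now, so their
# full-reduce output keeps shape [1] at this commit; they are migrated later.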
......@@ -2146,7 +2146,7 @@ void MaxPoolWithIndexInferMeta(const MetaTensor& x,
}
void MeanAllInferMeta(const MetaTensor& x, MetaTensor* out) {
out->set_dims(phi::make_ddim({1}));
out->set_dims(phi::make_ddim({}));
out->set_dtype(x.dtype());
out->set_layout(x.layout());
}
......@@ -3050,29 +3050,19 @@ DDim ReduceInferDim(const MetaTensor& x,
reduce_all = reduce_all || full_dim;
std::vector<int64_t> out_dim_vector;
if (keep_dim) {
for (int64_t i = 0; i < x_rank; ++i) {
if (reduce_all || dims_set.find(i) != dims_set.end()) {
for (int64_t i = 0; i < x_rank; ++i) {
if (reduce_all || dims_set.find(i) != dims_set.end()) {
if (keep_dim) {
out_dim_vector.push_back(1);
} else {
out_dim_vector.push_back(x.dims().at(i));
}
}
} else {
for (int64_t i = 0; i < x_rank; ++i) {
if (reduce_all || dims_set.find(i) != dims_set.end()) {
continue;
} else {
out_dim_vector.push_back(x.dims().at(i));
}
}
if (x_rank > 0 && out_dim_vector.size() == 0) {
out_dim_vector.push_back(1);
} else {
out_dim_vector.push_back(x.dims().at(i));
}
}
DDim out_dim = phi::make_ddim(out_dim_vector);
DDim out_dim = phi::make_ddim(out_dim_vector);
return out_dim;
}
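The rewritten loop folds the keep_dim / reduce_all branches into a single pass: a reduced axis contributes a 1 when keep_dim is set and is dropped otherwise, so a full reduce without keep_dim now infers an empty (0-D) shape. A pure-Python sketch of that rule (illustrative only; the helper name is mine):

def reduced_shape(in_shape, axes, keep_dim, reduce_all):
    # Mirrors the rewritten ReduceInferDim loop.
    axes = {a % len(in_shape) for a in axes} if in_shape else set()
    out = []
    for i, d in enumerate(in_shape):
        if reduce_all or i in axes:
            if keep_dim:
                out.append(1)
        else:
            out.append(d)
    return out

assert reduced_shape([3, 5], [], keep_dim=False, reduce_all=True) == []       # 0-D
assert reduced_shape([3, 5], [0], keep_dim=True, reduce_all=False) == [1, 5]
assert reduced_shape([3, 5], [1], keep_dim=False, reduce_all=False) == [3]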
......@@ -3086,14 +3076,14 @@ DDim ReduceInferDimForIntArrayAxis(const MetaTensor& x,
if (keep_dim) {
vec_dim = std::vector<int64_t>(x.dims().size(), 1);
} else {
vec_dim = {1};
vec_dim = {};
}
} else {
if (keep_dim) {
vec_dim = std::vector<int64_t>(x.dims().size(), -1);
} else {
auto x_rank = static_cast<size_t>(x.dims().size());
if (vec_axis.size() >= x_rank) {
if (vec_axis.size() > x_rank) {
vec_dim = {-1};
} else {
vec_dim = std::vector<int64_t>(x.dims().size() - vec_axis.size(), -1);
......@@ -3125,22 +3115,6 @@ void ReduceInferMetaBase(const MetaTensor& x,
out->set_layout(x.layout());
}
void ReduceIntArrayAxisInferMetaBase(const MetaTensor& x,
const IntArray& axis,
bool keep_dim,
bool reduce_all,
MetaTensor* out,
MetaConfig config) {
if (config.is_runtime || !axis.FromTensor()) {
ReduceInferMetaBase(x, axis.GetData(), keep_dim, reduce_all, out);
} else {
DDim out_dim = ReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all);
out->set_dims(out_dim);
out->set_dtype(x.dtype());
out->set_layout(x.layout());
}
}
void ReduceIntArrayAxisInferMeta(const MetaTensor& x,
const IntArray& axis,
bool keep_dim,
......@@ -3153,6 +3127,23 @@ void ReduceIntArrayAxisInferMeta(const MetaTensor& x,
ReduceIntArrayAxisInferMetaBase(x, axis, keep_dim, reduce_all, out, config);
}
void ReduceIntArrayAxisInferMetaBase(const MetaTensor& x,
const IntArray& axis,
bool keep_dim,
bool reduce_all,
MetaTensor* out,
MetaConfig config) {
DDim out_dim;
if (config.is_runtime || !axis.FromTensor()) {
out_dim = ReduceInferDim(x, axis.GetData(), keep_dim, reduce_all);
} else {
out_dim = ReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all);
}
out->set_dims(out_dim);
out->set_dtype(x.dtype());
out->set_layout(x.layout());
}
void ReduceScatterInferMeta(const MetaTensor& x, int nranks, MetaTensor* out) {
auto dim = x.dims();
if (dim[0] > 0 || dim[0] < -1) {
......@@ -3951,6 +3942,105 @@ void StridedSliceInferMeta(const MetaTensor& x,
x, axes, starts, ends, strides, infer_flags, decrease_axis, out, config);
}
// TODO(zhouwei): OriginReduceInferDim doesn't support 0D, remove in future
DDim OriginReduceInferDim(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
bool reduce_all) {
auto x_rank = x.dims().size();
std::vector<int64_t> formated_axis = axis;
for (size_t i = 0; i < axis.size(); ++i) {
if (x_rank == 0) {
PADDLE_ENFORCE_EQ(
axis[i] == 0 || axis[i] == -1,
true,
phi::errors::InvalidArgument(
"When input 0D Tensor, the axis can only be -1, 0, None or []"));
} else {
PADDLE_ENFORCE_LT(axis[i],
x_rank,
errors::InvalidArgument(
"The reduce dim index %d should be in the "
"range [ -dimension(X), dimension(X) ) "
"which dimesion = %d. But received dim index = %d.",
i,
x_rank,
axis[i]));
PADDLE_ENFORCE_GE(axis[i],
-x_rank,
errors::InvalidArgument(
"The reduce dim index %d should be in the "
"range [ -dimension(X), dimension(X) ) "
"which dimesion = %d. But received dim index = %d.",
i,
x_rank,
axis[i]));
}
if (axis[i] < 0) {
formated_axis[i] = axis[i] + x_rank;
}
}
bool full_dim = true;
std::set<int64_t> dims_set(formated_axis.begin(), formated_axis.end());
for (int64_t i = 0; i < x_rank; ++i) {
if (dims_set.find(i) == dims_set.end()) {
full_dim = false;
break;
}
}
reduce_all = reduce_all || full_dim;
std::vector<int64_t> out_dim_vector;
for (int64_t i = 0; i < x_rank; ++i) {
if (reduce_all || dims_set.find(i) != dims_set.end()) {
if (keep_dim) {
out_dim_vector.push_back(1);
} else {
continue;
}
} else {
out_dim_vector.push_back(x.dims().at(i));
}
}
if (x_rank > 0 && out_dim_vector.size() == 0) {
out_dim_vector.push_back(1);
}
DDim out_dim = phi::make_ddim(out_dim_vector);
return out_dim;
}
// TODO(zhouwei): OriginReduceInferDim doesn't support 0D, remove in future
DDim OriginReduceInferDimForIntArrayAxis(const MetaTensor& x,
const IntArray& axis,
bool keep_dim,
bool reduce_all) {
std::vector<int64_t> vec_axis = axis.GetData();
std::vector<int64_t> vec_dim;
if (reduce_all) {
if (keep_dim) {
vec_dim = std::vector<int64_t>(x.dims().size(), 1);
} else {
vec_dim = {1};
}
} else {
if (keep_dim) {
vec_dim = std::vector<int64_t>(x.dims().size(), -1);
} else {
auto x_rank = static_cast<size_t>(x.dims().size());
if (vec_axis.size() >= x_rank) {
vec_dim = {-1};
} else {
vec_dim = std::vector<int64_t>(x.dims().size() - vec_axis.size(), -1);
}
}
}
return phi::make_ddim(vec_dim);
}
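The two Origin* helpers above deliberately reproduce the pre-0-D convention that this commit is phasing out, which is why they carry TODO notes: for a full reduce without keep_dim they still emit a single dimension (1, or -1 when the size is unknown at compile time) instead of an empty shape. In short, for a full reduce of a [3, 5] input without keep_dim:

# Legacy rule kept by OriginReduceInferDim* (used by max/mean for now):
legacy_out_shape = [1]
# New rule used by ReduceInferDim for the migrated ops:
new_out_shape = []          # 0-D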
/* Why not use SumRawInferMeta directly?
Because we need make InferMetaFunction's args follow the design of
ops.yaml
......@@ -3977,9 +4067,10 @@ void SumRawInferMeta(const MetaTensor& x,
MetaConfig config) {
DDim out_dim;
if (config.is_runtime || !axis.FromTensor()) {
out_dim = ReduceInferDim(x, axis.GetData(), keep_dim, reduce_all);
out_dim = OriginReduceInferDim(x, axis.GetData(), keep_dim, reduce_all);
} else {
out_dim = ReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all);
out_dim =
OriginReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all);
}
DataType out_dtype;
......@@ -3998,6 +4089,38 @@ void SumRawInferMeta(const MetaTensor& x,
out->set_layout(x.layout());
}
// TODO(zhouwei): OriginReduce doesn't support 0D, remove in future
void OriginReduceInferMeta(const MetaTensor& x,
const IntArray& axis,
bool keep_dim,
MetaTensor* out,
MetaConfig config) {
bool reduce_all = false;
if (axis.size() == 0) {
reduce_all = true;
}
OriginReduceInferMetaBase(x, axis, keep_dim, reduce_all, out, config);
}
// TODO(zhouwei): OriginReduce doesn't support 0D, remove in future
void OriginReduceInferMetaBase(const MetaTensor& x,
const IntArray& axis,
bool keep_dim,
bool reduce_all,
MetaTensor* out,
MetaConfig config) {
DDim out_dim;
if (config.is_runtime || !axis.FromTensor()) {
out_dim = OriginReduceInferDim(x, axis.GetData(), keep_dim, reduce_all);
} else {
out_dim =
OriginReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all);
}
out->set_dims(out_dim);
out->set_dtype(x.dtype());
out->set_layout(x.layout());
}
void SvdInferMeta(const MetaTensor& x,
bool full_matrices,
MetaTensor* u,
......
......@@ -572,6 +572,19 @@ void SumRawInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void OriginReduceInferMeta(const MetaTensor& x,
const IntArray& axis,
bool keep_dim,
MetaTensor* out,
MetaConfig config = MetaConfig());
void OriginReduceInferMetaBase(const MetaTensor& x,
const IntArray& axis,
bool keep_dim,
bool reduce_all,
MetaTensor* out,
MetaConfig config = MetaConfig());
void SvdInferMeta(const MetaTensor& x,
bool full_matrices,
MetaTensor* u,
......
......@@ -32,10 +32,12 @@ void MeanAllGradKernel(const Context& dev_ctx,
out_grad.numel()));
dev_ctx.template Alloc<T>(x_grad);
T ig_size = static_cast<T>(x_grad->numel());
Eigen::DSizes<int, 1> bcast(static_cast<int>(ig_size));
EigenVector<T>::Flatten(*x_grad).device(*dev_ctx.eigen_device()) =
(EigenVector<T>::From(out_grad) / ig_size).broadcast(bcast);
T x_numel = static_cast<T>(x_grad->numel());
Eigen::DSizes<int, 1> bcast(static_cast<int>(x_numel));
auto eigen_x = EigenVector<T>::Flatten(*x_grad);
auto eigen_dout = EigenVector<T>::Flatten(out_grad);
eigen_x.device(*dev_ctx.eigen_device()) =
(eigen_dout / x_numel).broadcast(bcast);
}
} // namespace phi
......
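The rewritten kernel flattens out_grad explicitly (it can now be a 0-D tensor) and broadcasts grad / numel over the input. Numerically, the gradient of a full mean is 1/numel for every input element; a hedged check of that property through the Python API:

import numpy as np
import paddle

x = paddle.rand([4, 6])
x.stop_gradient = False
paddle.mean(x).backward()   # d(mean)/dx_i = 1 / numel
np.testing.assert_allclose(x.grad.numpy(), np.full((4, 6), 1.0 / 24.0), rtol=1e-6)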
......@@ -13,7 +13,8 @@
# limitations under the License
from collections import OrderedDict
from functools import reduce
import numpy as np
import paddle
from paddle.utils.flops import flops
......@@ -807,7 +808,7 @@ class CommOpCost(OpCost):
factor = 8
else:
raise ValueError(f"Unsupported comm dtype {dtype}")
comm_count = reduce(lambda x, y: x * y, shape) * factor
comm_count = int(np.prod(shape)) * factor
self._comm_count = comm_count
return self._comm_count
......
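Switching from functools.reduce to np.prod matters once a shape can be empty (a 0-D tensor has shape ()): reduce without an initializer raises TypeError on an empty sequence, while np.prod of an empty shape is 1, which is the correct element count. For illustration:

from functools import reduce
import numpy as np

assert int(np.prod([2, 3, 4])) == 24
assert int(np.prod([])) == 1            # empty (0-D) shape -> one element
try:
    reduce(lambda a, b: a * b, [])      # no initializer: fails on an empty shape
except TypeError:
    pass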
......@@ -286,7 +286,7 @@ class TestDistRunnerBase:
fetch_list=[avg_cost.name],
feed=feeder.feed(get_data()),
)
out_losses.append(loss[0])
out_losses.append(float(loss))
print_to_err(type(self).__name__, "run step %d finished" % i)
print_to_err(type(self).__name__, "trainer run finished")
print_to_err(type(self).__name__, f"dist losses: {out_losses}")
......@@ -382,7 +382,7 @@ class TestDistRunnerBase:
fetch_list=[avg_cost.name],
feed=feeder.feed(get_data()),
)
out_losses.append(loss[0])
out_losses.append(float(loss))
print_to_err(type(self).__name__, "run step %d finished" % i)
print_to_err(type(self).__name__, "trainer run finished")
......@@ -619,7 +619,7 @@ class TestDistRunnerBase:
(loss,) = exe.run(
binary, fetch_list=[avg_cost.name], feed=feeder.feed(get_data())
)
out_losses.append(loss[0])
out_losses.append(float(loss))
print_to_err(type(self).__name__, "run step %d finished" % i)
if lr_scheduler is not None:
lr_scheduler.step()
......
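The fetched loss is now a 0-D numpy array, and indexing a 0-D array with [0] raises IndexError; float() handles both the old shape-(1,) and the new shape-() results, which is why these call sites switch to it. For illustration:

import numpy as np

loss_0d = np.array(0.5)     # shape (), as fetched after this change
loss_1d = np.array([0.5])   # shape (1,), the old behavior
assert float(loss_0d) == float(loss_1d) == 0.5
try:
    loss_0d[0]              # 0-D arrays cannot be indexed
except IndexError:
    pass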
......@@ -31,17 +31,17 @@ class TestFunctionalL1Loss(unittest.TestCase):
dy_result = paddle.nn.functional.l1_loss(input, label)
expected = np.mean(np.abs(self.input_np - self.label_np))
np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05)
self.assertTrue(dy_result.shape, [1])
self.assertEqual(dy_result.shape, [])
dy_result = paddle.nn.functional.l1_loss(input, label, reduction='sum')
expected = np.sum(np.abs(self.input_np - self.label_np))
np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05)
self.assertTrue(dy_result.shape, [1])
self.assertEqual(dy_result.shape, [1])
dy_result = paddle.nn.functional.l1_loss(input, label, reduction='none')
expected = np.abs(self.input_np - self.label_np)
np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05)
self.assertTrue(dy_result.shape, [10, 10, 5])
self.assertEqual(dy_result.shape, [10, 10, 5])
def run_static(self, use_gpu=False):
input = paddle.static.data(
......@@ -119,19 +119,19 @@ class TestClassL1Loss(unittest.TestCase):
dy_result = l1_loss(input, label)
expected = np.mean(np.abs(self.input_np - self.label_np))
np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05)
self.assertTrue(dy_result.shape, [1])
self.assertEqual(dy_result.shape, [])
l1_loss = paddle.nn.loss.L1Loss(reduction='sum')
dy_result = l1_loss(input, label)
expected = np.sum(np.abs(self.input_np - self.label_np))
np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05)
self.assertTrue(dy_result.shape, [1])
self.assertEqual(dy_result.shape, [1])
l1_loss = paddle.nn.loss.L1Loss(reduction='none')
dy_result = l1_loss(input, label)
expected = np.abs(self.input_np - self.label_np)
np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05)
self.assertTrue(dy_result.shape, [10, 10, 5])
self.assertEqual(dy_result.shape, [10, 10, 5])
def run_static(self, use_gpu=False):
input = paddle.static.data(
......
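The updated assertions encode the new convention: l1_loss with reduction='mean' returns a 0-D tensor (shape []), reduction='sum' still returns shape [1] at this commit, and reduction='none' keeps the elementwise shape. They also replace assertTrue, which with two arguments never actually compared the shapes, with assertEqual. A hedged recap:

import paddle

x = paddle.rand([10, 10, 5])
y = paddle.rand([10, 10, 5])
assert paddle.nn.functional.l1_loss(x, y).shape == []                     # mean -> 0-D
assert paddle.nn.functional.l1_loss(x, y, reduction='sum').shape == [1]   # unchanged here
assert paddle.nn.functional.l1_loss(x, y, reduction='none').shape == [10, 10, 5]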
......@@ -197,6 +197,7 @@ class TestReduceAPI(unittest.TestCase):
out_empty_list = api(x, [])
self.assertEqual(out_empty_list, out)
self.assertEqual(out_empty_list.shape, [])
if x.grad is not None:
self.assertEqual(x.grad.shape, [])
......@@ -218,6 +219,27 @@ class TestReduceAPI(unittest.TestCase):
self.assertEqual(x.grad.shape, [])
np.testing.assert_allclose(x.grad.numpy(), np.array(3.0))
# 2) x is ND
if api in [
paddle.sum,
paddle.mean,
paddle.nanmean,
paddle.nansum,
paddle.max,
]:
return
x = paddle.rand([3, 5])
x.stop_gradient = False
out = api(x, None)
out.retain_grads()
out.backward()
self.assertEqual(out.shape, [])
if x.grad is not None:
self.assertEqual(out.grad.shape, [])
self.assertEqual(x.grad.shape, [3, 5])
paddle.enable_static()
def test_static_reduce(self):
......@@ -262,6 +284,31 @@ class TestReduceAPI(unittest.TestCase):
np.testing.assert_allclose(res[2], np.array(1.0))
np.testing.assert_allclose(res[3], np.array(1.0))
# 2) x is ND
if api in [
paddle.sum,
paddle.mean,
paddle.nanmean,
paddle.nansum,
paddle.max,
]:
return
x = paddle.rand([3, 5])
x.stop_gradient = False
out = api(x, None)
paddle.static.append_backward(out)
fetch_list = [out]
if block.has_var(x.grad_name):
fetch_list.extend([out.grad_name, x.grad_name])
res = exe.run(main_prog, fetch_list=fetch_list)
self.assertEqual(res[0].shape, ())
if len(res) > 1:
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (3, 5))
paddle.disable_static()
......@@ -1172,8 +1219,8 @@ class TestSundryAPI(unittest.TestCase):
def test_shape(self):
out = paddle.shape(self.x)
self.assertEqual(out.shape, [0])
np.testing.assert_array_equal(out.numpy(), np.array([]))
self.assertEqual(out.shape, [0])
def test_equal_scalar(self):
x = paddle.rand([])
......@@ -1233,6 +1280,16 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out.grad.shape, [])
self.assertEqual(x.grad.shape, [])
x1 = paddle.uniform([], None, -10, 10)
x1.stop_gradient = False
out1 = paddle.clip(x1, paddle.full([], 5.0), paddle.full([], 5.0))
out1.retain_grads()
out1.backward()
self.assertEqual(out1.shape, [])
self.assertEqual(out1.grad.shape, [])
self.assertEqual(x1.grad.shape, [])
def test_increment(self):
x = paddle.rand([])
x.stop_gradient = False
......@@ -1465,6 +1522,11 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out.grad.shape, [])
self.assertEqual(x.grad.shape, [])
def test_scale_(self):
x = paddle.rand([])
out = x.scale_(scale=2.0, bias=1.0)
self.assertEqual(out.shape, [])
def test_floor_divide(self):
# 1-d // 0-d
x = paddle.to_tensor([1, -2, 3], dtype="int64")
......@@ -1797,32 +1859,6 @@ class TestSundryAPI(unittest.TestCase):
# check grad shape with 1D repeats
self.assertEqual(x.grad.shape, [])
def test_sigmoid_focal_loss(self):
logit = paddle.to_tensor(
[[0.97, 0.91, 0.03], [0.55, 0.43, 0.71]],
dtype='float32',
stop_gradient=False,
)
logit.retain_grads()
label = paddle.to_tensor(
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype='float32'
)
fg_num_0 = paddle.full([], 2.0)
fg_num_1 = paddle.full([1], 2.0)
out0 = F.sigmoid_focal_loss(logit, label, normalizer=fg_num_0)
out1 = F.sigmoid_focal_loss(logit, label, normalizer=fg_num_1)
out0.retain_grads()
np.testing.assert_array_equal(
out0.numpy(),
out1.numpy(),
)
out0.backward()
self.assertEqual(out0.grad.shape, [1])
self.assertEqual(logit.grad.shape, [2, 3])
def test_allclose(self):
# 1) x is 0D
x = paddle.full([], 0.5)
......@@ -2305,6 +2341,7 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(res[3].shape, ())
self.assertEqual(res[3], 1.0)
@prog_scope()
def test_argmin(self):
# 1) x is 0D
x = paddle.rand([])
......@@ -2625,14 +2662,33 @@ class TestSundryAPIStatic(unittest.TestCase):
out = paddle.clip(x, -5, 5)
paddle.static.append_backward(out)
x1 = paddle.uniform([], None, -10, 10)
x1.stop_gradient = False
out1 = paddle.clip(x1, paddle.full([], 5.0), paddle.full([], 5.0))
paddle.static.append_backward(out1)
prog = paddle.static.default_main_program()
res = self.exe.run(
prog, fetch_list=[x, out, x.grad_name, out.grad_name]
prog,
fetch_list=[
x,
out,
x.grad_name,
out.grad_name,
x1,
out1,
x1.grad_name,
out1.grad_name,
],
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ())
self.assertEqual(res[3].shape, ())
self.assertEqual(res[4].shape, ())
self.assertEqual(res[5].shape, ())
self.assertEqual(res[6].shape, ())
self.assertEqual(res[7].shape, ())
@prog_scope()
def test_increment(self):
......@@ -2967,6 +3023,7 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(out2.shape, ())
self.assertEqual(out3.shape, ())
@prog_scope()
def test_add_n(self):
x1 = paddle.rand([])
x1.stop_gradient = False
......@@ -3589,15 +3646,14 @@ class TestSundryAPIStatic(unittest.TestCase):
np.testing.assert_array_equal(res[0], np.array(2))
@prog_scope()
def _test_shape(self):
def test_shape(self):
x = paddle.full([], 0.5)
out = paddle.shape(x)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out])
# 0-Size should be [ np.array([]) ], its [None] now
self.assertEqual(res[0].shape, (0))
np.testing.assert_array_equal(res[0], np.array([]))
self.assertEqual(res[0].shape, (0,))
def test_broadcast_tensors(self):
# 1) x is 0D, y is 0D
......@@ -4352,5 +4408,75 @@ class TestDistribution(unittest.TestCase):
# self.assertEqual(d.entropy().shape, [])
class TestLossAPI(unittest.TestCase):
def test_sigmoid_focal_loss(self):
logit = paddle.to_tensor(
[[0.97, 0.91, 0.03], [0.55, 0.43, 0.71]],
dtype='float32',
stop_gradient=False,
)
logit.retain_grads()
label = paddle.to_tensor(
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype='float32'
)
fg_num_0 = paddle.full([], 2.0)
fg_num_1 = paddle.full([1], 2.0)
out0 = F.sigmoid_focal_loss(
logit, label, normalizer=fg_num_0, reduction='mean'
)
out1 = F.sigmoid_focal_loss(
logit, label, normalizer=fg_num_1, reduction='mean'
)
out0.retain_grads()
np.testing.assert_array_equal(
out0.numpy(),
out1.numpy(),
)
out0.backward()
self.assertEqual(out0.shape, [])
self.assertEqual(out1.shape, [])
self.assertEqual(out0.grad.shape, [])
self.assertEqual(logit.grad.shape, [2, 3])
class TestLossAPIStatic(unittest.TestCase):
def setUp(self):
paddle.enable_static()
self.exe = paddle.static.Executor()
@prog_scope()
def test_sigmoid_focal_loss(self):
logit = paddle.rand([2, 3])
logit.stop_gradient = False
label = paddle.randint(0, 1, [2, 3]).astype('float32')
label.stop_gradient = False
fg_num_0 = paddle.full([], 2.0)
fg_num_1 = paddle.full([1], 2.0)
out0 = F.sigmoid_focal_loss(
logit, label, normalizer=fg_num_0, reduction='mean'
)
out1 = F.sigmoid_focal_loss(
logit, label, normalizer=fg_num_1, reduction='mean'
)
paddle.static.append_backward(out0.sum())
prog = paddle.static.default_main_program()
res = self.exe.run(
prog, fetch_list=[out0, out1, out0.grad_name, logit.grad_name]
)
np.testing.assert_allclose(res[0], res[1])
# because static use paddle.mean
# self.assertEqual(res[0].shape, ())
# self.assertEqual(res[1].shape, ())
# self.assertEqual(res[2].shape, ())
self.assertEqual(res[3].shape, (2, 3))
if __name__ == "__main__":
unittest.main()
......@@ -256,7 +256,10 @@ def jac(grad_fn, f, inputs):
_vs = vs.copy()
_vs[i] = _v
_, grads = grad_fn(f, inputs, _vs)
d_outs = paddle.concat([d_out.flatten() for d_out in grads])
if isinstance(grads, typing.Sequence):
d_outs = paddle.concat([d_out.flatten() for d_out in grads])
else:
d_outs = grads.flatten()
JJ_cols.append(d_outs)
# JJ is the fully unrolled jacobian
JJ = paddle.stack(JJ_cols)
......
......@@ -158,7 +158,7 @@ class TestConvertShapeCompare(unittest.TestCase):
fetch_list=[eq_out, not_eq_out, long_eq_out],
)
np.testing.assert_array_equal(
np.array(x_y_eq_out), np.array([[True], [False], [False]])
np.array(x_y_eq_out), np.array([True, False, False])
)
set_a_zero = np.ones([3, 2]).astype(np.float32)
......@@ -168,7 +168,7 @@ class TestConvertShapeCompare(unittest.TestCase):
fetch_list=[eq_out, not_eq_out, long_eq_out],
)
np.testing.assert_array_equal(
np.array(x_y_not_eq_out), np.array([[False], [True], [True]])
np.array(x_y_not_eq_out), np.array([False, True, True])
)
paddle.disable_static()
......
......@@ -573,7 +573,7 @@ class TestLACModel(unittest.TestCase):
words, targets, length = batch
start_time = time.time()
avg_cost, crf_decode = model(words, targets, length)
loss_data.append(avg_cost.numpy()[0])
loss_data.append(float(avg_cost))
# backward and optimization
avg_cost.backward()
......
......@@ -100,7 +100,7 @@ class TestPureFP16(TestMNIST):
scaled.backward()
scaler.minimize(optimizer, scaled)
loss_data.append(avg_loss.numpy()[0])
loss_data.append(float(avg_loss))
# save checkpoint
mnist.clear_gradients()
if batch_id % 2 == 0:
......
......@@ -176,7 +176,7 @@ def train(args, place, to_static):
state, reward, done, _ = env.step(action)
# log loss_probs
loss_data.append(loss.numpy()[0])
loss_data.append(float(loss))
policy.rewards.append(reward)
ep_reward += reward
......@@ -191,7 +191,7 @@ def train(args, place, to_static):
if i_episode % args.log_interval == 0:
print(
'Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f}\t loss_probs: {}'.format(
i_episode, ep_reward, running_reward, loss.numpy()[0]
i_episode, ep_reward, running_reward, float(loss)
)
)
......
......@@ -86,7 +86,7 @@ def train(to_static, build_strategy=None):
scaler.minimize(optimizer, scaled)
resnet.clear_gradients()
loss_data.append(avg_loss.numpy()[0])
loss_data.append(float(avg_loss))
total_loss += avg_loss
total_acc1 += acc_top1
total_acc5 += acc_top5
......
......@@ -261,7 +261,7 @@ def train_dygraph(args, batch_generator):
transformer.clear_gradients()
if step_idx % args.print_step == 0:
total_avg_cost = avg_cost.numpy() * trainer_count
avg_loss.append(total_avg_cost[0])
avg_loss.append(float(total_avg_cost))
if step_idx == 0:
logging.info(
"step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
......