Unverified commit e6def1eb, authored by zhouweiwei2014, committed by GitHub

[Zero-Dim] Support all/any/min/prod/logsumexp/amax/amin/some loss output 0D, test=allcase (#53051)

Parent commit: 7fa415ca
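A minimal sketch of the behavior this change targets, based on the updated tests further down (illustrative usage, not part of the patch):

```python
import paddle

x = paddle.rand([3, 5])

# Reducing an N-D tensor over all axes now yields a 0-D tensor (shape []),
# not a 1-element tensor of shape [1].
out = paddle.prod(x)   # likewise all/any/min/logsumexp/amax/amin
print(out.shape)       # []

# Losses with the default 'mean' reduction follow the same rule.
loss = paddle.nn.functional.l1_loss(paddle.rand([10]), paddle.rand([10]))
print(loss.shape)      # []
```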
@@ -58,8 +58,8 @@ class AssertOp : public framework::OperatorBase {
          "Input(Condition) of AssertOp is not found."));
  const phi::DenseTensor &cond = cond_var_ptr->Get<phi::DenseTensor>();
  PADDLE_ENFORCE_EQ(
-     cond.dims(),
-     phi::make_ddim({1}),
+     cond.numel(),
+     1,
      platform::errors::InvalidArgument(
          "The numel of Input(Condition) of AssertOp must be 1. But now "
          "the Condition's shape is %s.",
......
@@ -54,10 +54,9 @@ class ReduceMaxCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
}  // namespace operators
}  // namespace paddle

-DECLARE_INFER_SHAPE_FUNCTOR(
-    reduce_max,
-    ReduceMaxInferShapeFunctor,
-    PD_INFER_META(phi::ReduceIntArrayAxisInferMetaBase));
+DECLARE_INFER_SHAPE_FUNCTOR(reduce_max,
+                            ReduceMaxInferShapeFunctor,
+                            PD_INFER_META(phi::OriginReduceInferMetaBase));
REGISTER_OPERATOR(
    reduce_max,
......
@@ -98,10 +98,9 @@ class __reduce_meanMaker__ : public ops::ReduceBaseOpMaker {
  virtual std::string GetOpType() const { return "Reduce reduce_mean"; }
};

-DECLARE_INFER_SHAPE_FUNCTOR(
-    reduce_mean,
-    ReduceMeanInferShapeFunctor,
-    PD_INFER_META(phi::ReduceIntArrayAxisInferMetaBase));
+DECLARE_INFER_SHAPE_FUNCTOR(reduce_mean,
+                            ReduceMeanInferShapeFunctor,
+                            PD_INFER_META(phi::OriginReduceInferMetaBase));
REGISTER_OPERATOR(reduce_mean,
                  ops::ReduceBaseOp,
......
@@ -1282,7 +1282,7 @@ void prod_grad(const Tensor& x,
    if (!keep_dim) {
      auto axis_ = std::vector<int64_t>();
      if (reduce_all) {
-       for (int64_t i = 1; i < x_dim_size; i++) {
+       for (int64_t i = 0; i < x_dim_size; i++) {
          axis_.push_back(i);
        }
      } else {
......
@@ -775,7 +775,7 @@
  args : (Tensor x, IntArray axis={}, bool keepdim=false)
  output : Tensor(out)
  infer_meta :
-   func : ReduceIntArrayAxisInferMeta
+   func : OriginReduceInferMeta
  kernel :
    func : max
  backward : max_grad
@@ -811,7 +811,7 @@
  args : (Tensor x, IntArray axis={}, bool keepdim=false)
  output : Tensor(out)
  infer_meta :
-   func : ReduceIntArrayAxisInferMeta
+   func : OriginReduceInferMeta
  kernel :
    func : mean
  backward : mean_grad
......
@@ -2146,7 +2146,7 @@ void MaxPoolWithIndexInferMeta(const MetaTensor& x,
}

void MeanAllInferMeta(const MetaTensor& x, MetaTensor* out) {
-  out->set_dims(phi::make_ddim({1}));
+  out->set_dims(phi::make_ddim({}));
  out->set_dtype(x.dtype());
  out->set_layout(x.layout());
}
@@ -3050,29 +3050,19 @@ DDim ReduceInferDim(const MetaTensor& x,
  reduce_all = reduce_all || full_dim;

  std::vector<int64_t> out_dim_vector;
- if (keep_dim) {
-   for (int64_t i = 0; i < x_rank; ++i) {
-     if (reduce_all || dims_set.find(i) != dims_set.end()) {
-       out_dim_vector.push_back(1);
-     } else {
-       out_dim_vector.push_back(x.dims().at(i));
-     }
-   }
- } else {
-   for (int64_t i = 0; i < x_rank; ++i) {
-     if (reduce_all || dims_set.find(i) != dims_set.end()) {
-       continue;
-     } else {
-       out_dim_vector.push_back(x.dims().at(i));
-     }
-   }
-
-   if (x_rank > 0 && out_dim_vector.size() == 0) {
-     out_dim_vector.push_back(1);
-   }
- }
+ for (int64_t i = 0; i < x_rank; ++i) {
+   if (reduce_all || dims_set.find(i) != dims_set.end()) {
+     if (keep_dim) {
+       out_dim_vector.push_back(1);
+     } else {
+       continue;
+     }
+   } else {
+     out_dim_vector.push_back(x.dims().at(i));
+   }
+ }

  DDim out_dim = phi::make_ddim(out_dim_vector);
  return out_dim;
}
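As a reading aid, the reworked ReduceInferDim above boils down to the following output-shape rule; this is an illustrative Python sketch of the C++ logic, not Paddle code:

```python
def reduce_out_shape(x_shape, axis, keep_dim, reduce_all):
    """Mirror of the reworked ReduceInferDim loop: reduced axes become 1 when
    keep_dim is set, otherwise they are dropped entirely, so reducing every
    axis of an N-D input now yields an empty (0-D) shape."""
    rank = len(x_shape)
    axis = {a % rank for a in axis} if rank else set()
    reduce_all = reduce_all or axis == set(range(rank))
    out = []
    for i in range(rank):
        if reduce_all or i in axis:
            if keep_dim:
                out.append(1)
            # else: the reduced axis is dropped
        else:
            out.append(x_shape[i])
    return out

assert reduce_out_shape([3, 5], [0, 1], keep_dim=False, reduce_all=False) == []
assert reduce_out_shape([3, 5], [0], keep_dim=True, reduce_all=False) == [1, 5]
```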
@@ -3086,14 +3076,14 @@ DDim ReduceInferDimForIntArrayAxis(const MetaTensor& x,
    if (keep_dim) {
      vec_dim = std::vector<int64_t>(x.dims().size(), 1);
    } else {
-     vec_dim = {1};
+     vec_dim = {};
    }
  } else {
    if (keep_dim) {
      vec_dim = std::vector<int64_t>(x.dims().size(), -1);
    } else {
      auto x_rank = static_cast<size_t>(x.dims().size());
-     if (vec_axis.size() >= x_rank) {
+     if (vec_axis.size() > x_rank) {
        vec_dim = {-1};
      } else {
        vec_dim = std::vector<int64_t>(x.dims().size() - vec_axis.size(), -1);
@@ -3125,22 +3115,6 @@ void ReduceInferMetaBase(const MetaTensor& x,
  out->set_layout(x.layout());
}

-void ReduceIntArrayAxisInferMetaBase(const MetaTensor& x,
-                                     const IntArray& axis,
-                                     bool keep_dim,
-                                     bool reduce_all,
-                                     MetaTensor* out,
-                                     MetaConfig config) {
-  if (config.is_runtime || !axis.FromTensor()) {
-    ReduceInferMetaBase(x, axis.GetData(), keep_dim, reduce_all, out);
-  } else {
-    DDim out_dim = ReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all);
-    out->set_dims(out_dim);
-    out->set_dtype(x.dtype());
-    out->set_layout(x.layout());
-  }
-}
-
void ReduceIntArrayAxisInferMeta(const MetaTensor& x,
                                 const IntArray& axis,
                                 bool keep_dim,
@@ -3153,6 +3127,23 @@ void ReduceIntArrayAxisInferMeta(const MetaTensor& x,
  ReduceIntArrayAxisInferMetaBase(x, axis, keep_dim, reduce_all, out, config);
}

void ReduceIntArrayAxisInferMetaBase(const MetaTensor& x,
const IntArray& axis,
bool keep_dim,
bool reduce_all,
MetaTensor* out,
MetaConfig config) {
DDim out_dim;
if (config.is_runtime || !axis.FromTensor()) {
out_dim = ReduceInferDim(x, axis.GetData(), keep_dim, reduce_all);
} else {
out_dim = ReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all);
}
out->set_dims(out_dim);
out->set_dtype(x.dtype());
out->set_layout(x.layout());
}

void ReduceScatterInferMeta(const MetaTensor& x, int nranks, MetaTensor* out) {
  auto dim = x.dims();
  if (dim[0] > 0 || dim[0] < -1) {
@@ -3951,6 +3942,105 @@ void StridedSliceInferMeta(const MetaTensor& x,
      x, axes, starts, ends, strides, infer_flags, decrease_axis, out, config);
}

// TODO(zhouwei): OriginReduceInferDim doesn't support 0D, remove in future
DDim OriginReduceInferDim(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
bool reduce_all) {
auto x_rank = x.dims().size();
std::vector<int64_t> formated_axis = axis;
for (size_t i = 0; i < axis.size(); ++i) {
if (x_rank == 0) {
PADDLE_ENFORCE_EQ(
axis[i] == 0 || axis[i] == -1,
true,
phi::errors::InvalidArgument(
"When input 0D Tensor, the axis can only be -1, 0, None or []"));
} else {
PADDLE_ENFORCE_LT(axis[i],
x_rank,
errors::InvalidArgument(
"The reduce dim index %d should be in the "
"range [ -dimension(X), dimension(X) ) "
"which dimesion = %d. But received dim index = %d.",
i,
x_rank,
axis[i]));
PADDLE_ENFORCE_GE(axis[i],
-x_rank,
errors::InvalidArgument(
"The reduce dim index %d should be in the "
"range [ -dimension(X), dimension(X) ) "
"which dimesion = %d. But received dim index = %d.",
i,
x_rank,
axis[i]));
}
if (axis[i] < 0) {
formated_axis[i] = axis[i] + x_rank;
}
}
bool full_dim = true;
std::set<int64_t> dims_set(formated_axis.begin(), formated_axis.end());
for (int64_t i = 0; i < x_rank; ++i) {
if (dims_set.find(i) == dims_set.end()) {
full_dim = false;
break;
}
}
reduce_all = reduce_all || full_dim;
std::vector<int64_t> out_dim_vector;
for (int64_t i = 0; i < x_rank; ++i) {
if (reduce_all || dims_set.find(i) != dims_set.end()) {
if (keep_dim) {
out_dim_vector.push_back(1);
} else {
continue;
}
} else {
out_dim_vector.push_back(x.dims().at(i));
}
}
if (x_rank > 0 && out_dim_vector.size() == 0) {
out_dim_vector.push_back(1);
}
DDim out_dim = phi::make_ddim(out_dim_vector);
return out_dim;
}
// TODO(zhouwei): OriginReduceInferDim doesn't support 0D, remove in future
DDim OriginReduceInferDimForIntArrayAxis(const MetaTensor& x,
const IntArray& axis,
bool keep_dim,
bool reduce_all) {
std::vector<int64_t> vec_axis = axis.GetData();
std::vector<int64_t> vec_dim;
if (reduce_all) {
if (keep_dim) {
vec_dim = std::vector<int64_t>(x.dims().size(), 1);
} else {
vec_dim = {1};
}
} else {
if (keep_dim) {
vec_dim = std::vector<int64_t>(x.dims().size(), -1);
} else {
auto x_rank = static_cast<size_t>(x.dims().size());
if (vec_axis.size() >= x_rank) {
vec_dim = {-1};
} else {
vec_dim = std::vector<int64_t>(x.dims().size() - vec_axis.size(), -1);
}
}
}
return phi::make_ddim(vec_dim);
}
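For contrast, the legacy rule preserved by the Origin* helpers above differs only in the full-reduction case without keep_dim, where it still reports a 1-element shape instead of a 0-D shape. An illustrative Python sketch:

```python
def origin_reduce_out_shape(x_shape, axis, keep_dim, reduce_all):
    """Sketch of the legacy OriginReduceInferDim rule kept above: same loop as
    the new rule, except a full reduction without keep_dim falls back to a
    single-element shape [1] instead of the 0-D shape []."""
    rank = len(x_shape)
    axis = {a % rank for a in axis} if rank else set()
    reduce_all = reduce_all or axis == set(range(rank))
    out = []
    for i in range(rank):
        if reduce_all or i in axis:
            if keep_dim:
                out.append(1)
        else:
            out.append(x_shape[i])
    if rank > 0 and not out:
        out.append(1)  # legacy behavior: never emit a 0-D shape for N-D inputs
    return out

assert origin_reduce_out_shape([3, 5], [0, 1], keep_dim=False, reduce_all=False) == [1]
```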

/* Why not use SumRawInferMeta directly?
   Because we need make InferMetaFunction's args follow the design of
   ops.yaml
@@ -3977,9 +4067,10 @@ void SumRawInferMeta(const MetaTensor& x,
                     MetaConfig config) {
  DDim out_dim;
  if (config.is_runtime || !axis.FromTensor()) {
-   out_dim = ReduceInferDim(x, axis.GetData(), keep_dim, reduce_all);
+   out_dim = OriginReduceInferDim(x, axis.GetData(), keep_dim, reduce_all);
  } else {
-   out_dim = ReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all);
+   out_dim =
+       OriginReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all);
  }

  DataType out_dtype;
@@ -3998,6 +4089,38 @@ void SumRawInferMeta(const MetaTensor& x,
  out->set_layout(x.layout());
}

// TODO(zhouwei): OriginReduce doesn't support 0D, remove in future
void OriginReduceInferMeta(const MetaTensor& x,
const IntArray& axis,
bool keep_dim,
MetaTensor* out,
MetaConfig config) {
bool reduce_all = false;
if (axis.size() == 0) {
reduce_all = true;
}
OriginReduceInferMetaBase(x, axis, keep_dim, reduce_all, out, config);
}
// TODO(zhouwei): OriginReduce doesn't support 0D, remove in future
void OriginReduceInferMetaBase(const MetaTensor& x,
const IntArray& axis,
bool keep_dim,
bool reduce_all,
MetaTensor* out,
MetaConfig config) {
DDim out_dim;
if (config.is_runtime || !axis.FromTensor()) {
out_dim = OriginReduceInferDim(x, axis.GetData(), keep_dim, reduce_all);
} else {
out_dim =
OriginReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all);
}
out->set_dims(out_dim);
out->set_dtype(x.dtype());
out->set_layout(x.layout());
}

void SvdInferMeta(const MetaTensor& x,
                  bool full_matrices,
                  MetaTensor* u,
......
@@ -572,6 +572,19 @@ void SumRawInferMeta(const MetaTensor& x,
                     MetaTensor* out,
                     MetaConfig config = MetaConfig());

void OriginReduceInferMeta(const MetaTensor& x,
const IntArray& axis,
bool keep_dim,
MetaTensor* out,
MetaConfig config = MetaConfig());
void OriginReduceInferMetaBase(const MetaTensor& x,
const IntArray& axis,
bool keep_dim,
bool reduce_all,
MetaTensor* out,
MetaConfig config = MetaConfig());

void SvdInferMeta(const MetaTensor& x,
                  bool full_matrices,
                  MetaTensor* u,
......
@@ -32,10 +32,12 @@ void MeanAllGradKernel(const Context& dev_ctx,
                        out_grad.numel()));
  dev_ctx.template Alloc<T>(x_grad);

- T ig_size = static_cast<T>(x_grad->numel());
- Eigen::DSizes<int, 1> bcast(static_cast<int>(ig_size));
- EigenVector<T>::Flatten(*x_grad).device(*dev_ctx.eigen_device()) =
-     (EigenVector<T>::From(out_grad) / ig_size).broadcast(bcast);
+ T x_numel = static_cast<T>(x_grad->numel());
+ Eigen::DSizes<int, 1> bcast(static_cast<int>(x_numel));
+ auto eigen_x = EigenVector<T>::Flatten(*x_grad);
+ auto eigen_dout = EigenVector<T>::Flatten(out_grad);
+ eigen_x.device(*dev_ctx.eigen_device()) =
+     (eigen_dout / x_numel).broadcast(bcast);
}

}  // namespace phi
......
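The rewritten kernel computes the same quantity as before, only with named intermediates: the gradient of a mean over all elements is the upstream gradient divided by the element count, broadcast back to the input shape. A NumPy sketch of that relation (illustrative only):

```python
import numpy as np

def mean_all_grad(x_shape, dout):
    # dout is the (scalar) gradient flowing into mean(x); every element of x
    # receives dout / numel(x).
    numel = int(np.prod(x_shape))
    return np.full(x_shape, float(dout) / numel, dtype=np.float64)

dx = mean_all_grad((2, 3), 1.0)
assert np.allclose(dx, np.full((2, 3), 1.0 / 6))
```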
@@ -13,7 +13,8 @@
# limitations under the License

from collections import OrderedDict
-from functools import reduce
+
+import numpy as np

import paddle
from paddle.utils.flops import flops
@@ -807,7 +808,7 @@ class CommOpCost(OpCost):
                factor = 8
            else:
                raise ValueError(f"Unsupported comm dtype {dtype}")
-           comm_count = reduce(lambda x, y: x * y, shape) * factor
+           comm_count = int(np.prod(shape)) * factor
            self._comm_count = comm_count

        return self._comm_count
......
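Besides being shorter, np.prod is well defined for an empty (0-D) shape, where functools.reduce without an initial value raises. A quick check (illustrative):

```python
import numpy as np
from functools import reduce

shape = []  # shape of a 0-D tensor
print(int(np.prod(shape)))                     # 1
try:
    reduce(lambda a, b: a * b, shape)
except TypeError as e:
    print("reduce fails on an empty shape:", e)
```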
@@ -286,7 +286,7 @@ class TestDistRunnerBase:
                fetch_list=[avg_cost.name],
                feed=feeder.feed(get_data()),
            )
-           out_losses.append(loss[0])
+           out_losses.append(float(loss))
            print_to_err(type(self).__name__, "run step %d finished" % i)
        print_to_err(type(self).__name__, "trainer run finished")
        print_to_err(type(self).__name__, f"dist losses: {out_losses}")
@@ -382,7 +382,7 @@ class TestDistRunnerBase:
                fetch_list=[avg_cost.name],
                feed=feeder.feed(get_data()),
            )
-           out_losses.append(loss[0])
+           out_losses.append(float(loss))
            print_to_err(type(self).__name__, "run step %d finished" % i)
        print_to_err(type(self).__name__, "trainer run finished")
@@ -619,7 +619,7 @@ class TestDistRunnerBase:
            (loss,) = exe.run(
                binary, fetch_list=[avg_cost.name], feed=feeder.feed(get_data())
            )
-           out_losses.append(loss[0])
+           out_losses.append(float(loss))
            print_to_err(type(self).__name__, "run step %d finished" % i)
            if lr_scheduler is not None:
                lr_scheduler.step()
......
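float(loss) works whether the fetched loss comes back as a 0-D array or as the old 1-element array, while loss[0] assumes at least one dimension. A quick NumPy check (illustrative):

```python
import numpy as np

scalar_loss = np.array(0.25)    # 0-D result after this change
vector_loss = np.array([0.25])  # old 1-element result

print(float(scalar_loss), float(vector_loss))  # both work: 0.25 0.25
try:
    scalar_loss[0]
except IndexError as e:
    print("indexing a 0-D array fails:", e)
```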
@@ -31,17 +31,17 @@ class TestFunctionalL1Loss(unittest.TestCase):
        dy_result = paddle.nn.functional.l1_loss(input, label)
        expected = np.mean(np.abs(self.input_np - self.label_np))
        np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05)
-       self.assertTrue(dy_result.shape, [1])
+       self.assertEqual(dy_result.shape, [])

        dy_result = paddle.nn.functional.l1_loss(input, label, reduction='sum')
        expected = np.sum(np.abs(self.input_np - self.label_np))
        np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05)
-       self.assertTrue(dy_result.shape, [1])
+       self.assertEqual(dy_result.shape, [1])

        dy_result = paddle.nn.functional.l1_loss(input, label, reduction='none')
        expected = np.abs(self.input_np - self.label_np)
        np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05)
-       self.assertTrue(dy_result.shape, [10, 10, 5])
+       self.assertEqual(dy_result.shape, [10, 10, 5])

    def run_static(self, use_gpu=False):
        input = paddle.static.data(
@@ -119,19 +119,19 @@ class TestClassL1Loss(unittest.TestCase):
        dy_result = l1_loss(input, label)
        expected = np.mean(np.abs(self.input_np - self.label_np))
        np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05)
-       self.assertTrue(dy_result.shape, [1])
+       self.assertEqual(dy_result.shape, [])

        l1_loss = paddle.nn.loss.L1Loss(reduction='sum')
        dy_result = l1_loss(input, label)
        expected = np.sum(np.abs(self.input_np - self.label_np))
        np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05)
-       self.assertTrue(dy_result.shape, [1])
+       self.assertEqual(dy_result.shape, [1])

        l1_loss = paddle.nn.loss.L1Loss(reduction='none')
        dy_result = l1_loss(input, label)
        expected = np.abs(self.input_np - self.label_np)
        np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05)
-       self.assertTrue(dy_result.shape, [10, 10, 5])
+       self.assertEqual(dy_result.shape, [10, 10, 5])

    def run_static(self, use_gpu=False):
        input = paddle.static.data(
......
@@ -197,6 +197,7 @@ class TestReduceAPI(unittest.TestCase):
                out_empty_list = api(x, [])
                self.assertEqual(out_empty_list, out)
+               self.assertEqual(out_empty_list.shape, [])

            if x.grad is not None:
                self.assertEqual(x.grad.shape, [])
@@ -218,6 +219,27 @@ class TestReduceAPI(unittest.TestCase):
                self.assertEqual(x.grad.shape, [])
                np.testing.assert_allclose(x.grad.numpy(), np.array(3.0))

# 2) x is ND
if api in [
paddle.sum,
paddle.mean,
paddle.nanmean,
paddle.nansum,
paddle.max,
]:
return
x = paddle.rand([3, 5])
x.stop_gradient = False
out = api(x, None)
out.retain_grads()
out.backward()
self.assertEqual(out.shape, [])
if x.grad is not None:
self.assertEqual(out.grad.shape, [])
self.assertEqual(x.grad.shape, [3, 5])

        paddle.enable_static()

    def test_static_reduce(self):
@@ -262,6 +284,31 @@ class TestReduceAPI(unittest.TestCase):
                np.testing.assert_allclose(res[2], np.array(1.0))
                np.testing.assert_allclose(res[3], np.array(1.0))

# 2) x is ND
if api in [
paddle.sum,
paddle.mean,
paddle.nanmean,
paddle.nansum,
paddle.max,
]:
return
x = paddle.rand([3, 5])
x.stop_gradient = False
out = api(x, None)
paddle.static.append_backward(out)
fetch_list = [out]
if block.has_var(x.grad_name):
fetch_list.extend([out.grad_name, x.grad_name])
res = exe.run(main_prog, fetch_list=fetch_list)
self.assertEqual(res[0].shape, ())
if len(res) > 1:
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (3, 5))

        paddle.disable_static()

@@ -1172,8 +1219,8 @@ class TestSundryAPI(unittest.TestCase):

    def test_shape(self):
        out = paddle.shape(self.x)
-       self.assertEqual(out.shape, [0])
        np.testing.assert_array_equal(out.numpy(), np.array([]))
+       self.assertEqual(out.shape, [0])

    def test_equal_scalar(self):
        x = paddle.rand([])
@@ -1233,6 +1280,16 @@ class TestSundryAPI(unittest.TestCase):
        self.assertEqual(out.grad.shape, [])
        self.assertEqual(x.grad.shape, [])

x1 = paddle.uniform([], None, -10, 10)
x1.stop_gradient = False
out1 = paddle.clip(x1, paddle.full([], 5.0), paddle.full([], 5.0))
out1.retain_grads()
out1.backward()
self.assertEqual(out1.shape, [])
self.assertEqual(out1.grad.shape, [])
self.assertEqual(x1.grad.shape, [])

    def test_increment(self):
        x = paddle.rand([])
        x.stop_gradient = False
@@ -1465,6 +1522,11 @@ class TestSundryAPI(unittest.TestCase):
        self.assertEqual(out.grad.shape, [])
        self.assertEqual(x.grad.shape, [])

def test_scale_(self):
x = paddle.rand([])
out = x.scale_(scale=2.0, bias=1.0)
self.assertEqual(out.shape, [])

    def test_floor_divide(self):
        # 1-d // 0-d
        x = paddle.to_tensor([1, -2, 3], dtype="int64")
@@ -1797,32 +1859,6 @@ class TestSundryAPI(unittest.TestCase):
        # check grad shape with 1D repeats
        self.assertEqual(x.grad.shape, [])

-   def test_sigmoid_focal_loss(self):
-       logit = paddle.to_tensor(
-           [[0.97, 0.91, 0.03], [0.55, 0.43, 0.71]],
-           dtype='float32',
-           stop_gradient=False,
-       )
-       logit.retain_grads()
-       label = paddle.to_tensor(
-           [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype='float32'
-       )
-       fg_num_0 = paddle.full([], 2.0)
-       fg_num_1 = paddle.full([1], 2.0)
-
-       out0 = F.sigmoid_focal_loss(logit, label, normalizer=fg_num_0)
-       out1 = F.sigmoid_focal_loss(logit, label, normalizer=fg_num_1)
-       out0.retain_grads()
-
-       np.testing.assert_array_equal(
-           out0.numpy(),
-           out1.numpy(),
-       )
-
-       out0.backward()
-       self.assertEqual(out0.grad.shape, [1])
-       self.assertEqual(logit.grad.shape, [2, 3])
-
    def test_allclose(self):
        # 1) x is 0D
        x = paddle.full([], 0.5)
@@ -2305,6 +2341,7 @@ class TestSundryAPIStatic(unittest.TestCase):
        self.assertEqual(res[3].shape, ())
        self.assertEqual(res[3], 1.0)

+   @prog_scope()
    def test_argmin(self):
        # 1) x is 0D
        x = paddle.rand([])
@@ -2625,14 +2662,33 @@ class TestSundryAPIStatic(unittest.TestCase):
        out = paddle.clip(x, -5, 5)
        paddle.static.append_backward(out)

+       x1 = paddle.uniform([], None, -10, 10)
+       x1.stop_gradient = False
+       out1 = paddle.clip(x1, paddle.full([], 5.0), paddle.full([], 5.0))
+       paddle.static.append_backward(out1)
+
        prog = paddle.static.default_main_program()
        res = self.exe.run(
-           prog, fetch_list=[x, out, x.grad_name, out.grad_name]
+           prog,
+           fetch_list=[
+               x,
+               out,
+               x.grad_name,
+               out.grad_name,
+               x1,
+               out1,
+               x1.grad_name,
+               out1.grad_name,
+           ],
        )
        self.assertEqual(res[0].shape, ())
        self.assertEqual(res[1].shape, ())
        self.assertEqual(res[2].shape, ())
        self.assertEqual(res[3].shape, ())
+       self.assertEqual(res[4].shape, ())
+       self.assertEqual(res[5].shape, ())
+       self.assertEqual(res[6].shape, ())
+       self.assertEqual(res[7].shape, ())

    @prog_scope()
    def test_increment(self):
@@ -2967,6 +3023,7 @@ class TestSundryAPIStatic(unittest.TestCase):
        self.assertEqual(out2.shape, ())
        self.assertEqual(out3.shape, ())

+   @prog_scope()
    def test_add_n(self):
        x1 = paddle.rand([])
        x1.stop_gradient = False
@@ -3589,15 +3646,14 @@ class TestSundryAPIStatic(unittest.TestCase):
        np.testing.assert_array_equal(res[0], np.array(2))

    @prog_scope()
-   def _test_shape(self):
+   def test_shape(self):
        x = paddle.full([], 0.5)
        out = paddle.shape(x)

        prog = paddle.static.default_main_program()
        res = self.exe.run(prog, fetch_list=[out])
-       # 0-Size should be [ np.array([]) ], its [None] now
-       self.assertEqual(res[0].shape, (0))
        np.testing.assert_array_equal(res[0], np.array([]))
+       self.assertEqual(res[0].shape, (0,))

    def test_broadcast_tensors(self):
        # 1) x is 0D, y is 0D
@@ -4352,5 +4408,75 @@ class TestDistribution(unittest.TestCase):
        # self.assertEqual(d.entropy().shape, [])

class TestLossAPI(unittest.TestCase):
def test_sigmoid_focal_loss(self):
logit = paddle.to_tensor(
[[0.97, 0.91, 0.03], [0.55, 0.43, 0.71]],
dtype='float32',
stop_gradient=False,
)
logit.retain_grads()
label = paddle.to_tensor(
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype='float32'
)
fg_num_0 = paddle.full([], 2.0)
fg_num_1 = paddle.full([1], 2.0)
out0 = F.sigmoid_focal_loss(
logit, label, normalizer=fg_num_0, reduction='mean'
)
out1 = F.sigmoid_focal_loss(
logit, label, normalizer=fg_num_1, reduction='mean'
)
out0.retain_grads()
np.testing.assert_array_equal(
out0.numpy(),
out1.numpy(),
)
out0.backward()
self.assertEqual(out0.shape, [])
self.assertEqual(out1.shape, [])
self.assertEqual(out0.grad.shape, [])
self.assertEqual(logit.grad.shape, [2, 3])
class TestLossAPIStatic(unittest.TestCase):
def setUp(self):
paddle.enable_static()
self.exe = paddle.static.Executor()
@prog_scope()
def test_sigmoid_focal_loss(self):
logit = paddle.rand([2, 3])
logit.stop_gradient = False
label = paddle.randint(0, 1, [2, 3]).astype('float32')
label.stop_gradient = False
fg_num_0 = paddle.full([], 2.0)
fg_num_1 = paddle.full([1], 2.0)
out0 = F.sigmoid_focal_loss(
logit, label, normalizer=fg_num_0, reduction='mean'
)
out1 = F.sigmoid_focal_loss(
logit, label, normalizer=fg_num_1, reduction='mean'
)
paddle.static.append_backward(out0.sum())
prog = paddle.static.default_main_program()
res = self.exe.run(
prog, fetch_list=[out0, out1, out0.grad_name, logit.grad_name]
)
np.testing.assert_allclose(res[0], res[1])
# because static use paddle.mean
# self.assertEqual(res[0].shape, ())
# self.assertEqual(res[1].shape, ())
# self.assertEqual(res[2].shape, ())
self.assertEqual(res[3].shape, (2, 3))

if __name__ == "__main__":
    unittest.main()
@@ -256,7 +256,10 @@ def jac(grad_fn, f, inputs):
            _vs = vs.copy()
            _vs[i] = _v
            _, grads = grad_fn(f, inputs, _vs)
-           d_outs = paddle.concat([d_out.flatten() for d_out in grads])
+           if isinstance(grads, typing.Sequence):
+               d_outs = paddle.concat([d_out.flatten() for d_out in grads])
+           else:
+               d_outs = grads.flatten()
            JJ_cols.append(d_outs)
    # JJ is the fully unrolled jacobian
    JJ = paddle.stack(JJ_cols)
......
@@ -158,7 +158,7 @@ class TestConvertShapeCompare(unittest.TestCase):
                fetch_list=[eq_out, not_eq_out, long_eq_out],
            )
            np.testing.assert_array_equal(
-               np.array(x_y_eq_out), np.array([[True], [False], [False]])
+               np.array(x_y_eq_out), np.array([True, False, False])
            )

            set_a_zero = np.ones([3, 2]).astype(np.float32)
@@ -168,7 +168,7 @@
                fetch_list=[eq_out, not_eq_out, long_eq_out],
            )
            np.testing.assert_array_equal(
-               np.array(x_y_not_eq_out), np.array([[False], [True], [True]])
+               np.array(x_y_not_eq_out), np.array([False, True, True])
            )
            paddle.disable_static()
......
@@ -573,7 +573,7 @@ class TestLACModel(unittest.TestCase):
                words, targets, length = batch
                start_time = time.time()
                avg_cost, crf_decode = model(words, targets, length)
-               loss_data.append(avg_cost.numpy()[0])
+               loss_data.append(float(avg_cost))
                # backward and optimization
                avg_cost.backward()
......
@@ -100,7 +100,7 @@ class TestPureFP16(TestMNIST):
                    scaled.backward()
                    scaler.minimize(optimizer, scaled)

-                   loss_data.append(avg_loss.numpy()[0])
+                   loss_data.append(float(avg_loss))
                    # save checkpoint
                    mnist.clear_gradients()
                    if batch_id % 2 == 0:
......
@@ -176,7 +176,7 @@ def train(args, place, to_static):
            state, reward, done, _ = env.step(action)

            # log loss_probs
-           loss_data.append(loss.numpy()[0])
+           loss_data.append(float(loss))

            policy.rewards.append(reward)
            ep_reward += reward
@@ -191,7 +191,7 @@ def train(args, place, to_static):
        if i_episode % args.log_interval == 0:
            print(
                'Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f}\t loss_probs: {}'.format(
-                   i_episode, ep_reward, running_reward, loss.numpy()[0]
+                   i_episode, ep_reward, running_reward, float(loss)
                )
            )
......
@@ -86,7 +86,7 @@ def train(to_static, build_strategy=None):
                scaler.minimize(optimizer, scaled)
                resnet.clear_gradients()

-               loss_data.append(avg_loss.numpy()[0])
+               loss_data.append(float(avg_loss))
                total_loss += avg_loss
                total_acc1 += acc_top1
                total_acc5 += acc_top5
......
@@ -261,7 +261,7 @@ def train_dygraph(args, batch_generator):
            transformer.clear_gradients()
            if step_idx % args.print_step == 0:
                total_avg_cost = avg_cost.numpy() * trainer_count
-               avg_loss.append(total_avg_cost[0])
+               avg_loss.append(float(total_avg_cost))
                if step_idx == 0:
                    logging.info(
                        "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
......