未验证 提交 3beafff2 编写于 作者: E emailweixu 提交者: GitHub

Merge pull request #8415 from emailweixu/print_op

Make print_op able to show the value of bool tensor
...@@ -314,7 +314,6 @@ EIGEN_FUNCTOR(Div, EIGEN_DIV); ...@@ -314,7 +314,6 @@ EIGEN_FUNCTOR(Div, EIGEN_DIV);
template <typename DeviceContext, typename T, typename functor, template <typename DeviceContext, typename T, typename functor,
typename broadcastfunctor, typename broadcast2functor> typename broadcastfunctor, typename broadcast2functor>
void ElementwiseGradCompute(const framework::ExecutionContext& ctx, void ElementwiseGradCompute(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* x,
const framework::Tensor* y, const framework::Tensor* y,
const framework::Tensor* out, const framework::Tensor* out,
......
...@@ -46,7 +46,7 @@ struct Formater { ...@@ -46,7 +46,7 @@ struct Formater {
} }
private: private:
void PrintMessage() { CLOG << std::time(nullptr) << "\t" << message; } void PrintMessage() { CLOG << std::time(nullptr) << "\t" << message << "\t"; }
void PrintName() { void PrintName() {
if (!name.empty()) { if (!name.empty()) {
CLOG << "Tensor[" << name << "]" << std::endl; CLOG << "Tensor[" << name << "]" << std::endl;
...@@ -85,15 +85,16 @@ struct Formater { ...@@ -85,15 +85,16 @@ struct Formater {
// print float // print float
if (dtype.hash_code() == typeid(float).hash_code()) { if (dtype.hash_code() == typeid(float).hash_code()) {
Display<float>(size); Display<float>(size);
} } else if (dtype.hash_code() == typeid(double).hash_code()) {
if (dtype.hash_code() == typeid(double).hash_code()) {
Display<double>(size); Display<double>(size);
} } else if (dtype.hash_code() == typeid(int).hash_code()) {
if (dtype.hash_code() == typeid(int).hash_code()) {
Display<int>(size); Display<int>(size);
} } else if (dtype.hash_code() == typeid(int64_t).hash_code()) {
if (dtype.hash_code() == typeid(int64_t).hash_code()) {
Display<int64_t>(size); Display<int64_t>(size);
} else if (dtype.hash_code() == typeid(bool).hash_code()) {
Display<bool>(size);
} else {
CLOG << "\tdata: unprintable type: " << dtype.name() << std::endl;
} }
} }
...@@ -182,6 +183,7 @@ class TensorPrintOp : public framework::OperatorBase { ...@@ -182,6 +183,7 @@ class TensorPrintOp : public framework::OperatorBase {
} }
Formater formater; Formater formater;
formater.message = Attr<std::string>("message");
if (Attr<bool>("print_tensor_name")) { if (Attr<bool>("print_tensor_name")) {
formater.name = printed_var_name; formater.name = printed_var_name;
} }
......
...@@ -174,7 +174,7 @@ def Print(input, ...@@ -174,7 +174,7 @@ def Print(input,
print_tensor_type (bool): Print the tensor type. print_tensor_type (bool): Print the tensor type.
print_tensor_shape (bool): Print the tensor shape. print_tensor_shape (bool): Print the tensor shape.
print_tensor_lod (bool): Print the tensor lod. print_tensor_lod (bool): Print the tensor lod.
print_phase (bool): Which phase to displace, including 'forward', print_phase (str): Which phase to displace, including 'forward',
'backward' and 'both'. If set to 'backward' or 'both', will 'backward' and 'both'. If set to 'backward' or 'both', will
print the gradients of input tensor. print the gradients of input tensor.
......
...@@ -2070,7 +2070,7 @@ def reduce_sum(input, dim=None, keep_dim=False, name=None): ...@@ -2070,7 +2070,7 @@ def reduce_sum(input, dim=None, keep_dim=False, name=None):
Tensor variable with a single element, otherwise must be in the Tensor variable with a single element, otherwise must be in the
range :math:`[-rank(input), rank(input))`. If :math:`dim < 0`, range :math:`[-rank(input), rank(input))`. If :math:`dim < 0`,
the dimension to reduce is :math:`rank + dim`. the dimension to reduce is :math:`rank + dim`.
keep_dim (bool): Whether to reserve the reduced dimension in the keep_dim (bool|False): Whether to reserve the reduced dimension in the
output Tensor. The result tensor will have one fewer dimension output Tensor. The result tensor will have one fewer dimension
than the :attr:`input` unless :attr:`keep_dim` is true. than the :attr:`input` unless :attr:`keep_dim` is true.
name(str|None): A name for this layer(optional). If set None, the layer name(str|None): A name for this layer(optional). If set None, the layer
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册