提交 822f2834 编写于 作者: S sweetsky0901

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into my_unpool_max_2d

......@@ -98,7 +98,6 @@ $y = \max(x, 0)$
}
};
template <typename AttrType>
class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LeakyReluOpMaker(framework::OpProto *proto,
......@@ -106,8 +105,7 @@ class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of LeakyRelu operator");
AddOutput("Y", "Output of LeakyRelu operator");
AddAttr<AttrType>("alpha", "The small negative slope")
.SetDefault(static_cast<AttrType>(0.02f));
AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f);
AddComment(R"DOC(
LeakyRelu Activation Operator.
......@@ -117,7 +115,6 @@ $y = \max(x, \alpha * x)$
}
};
template <typename AttrType>
class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SoftShrinkOpMaker(framework::OpProto *proto,
......@@ -125,8 +122,7 @@ class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Softshrink operator");
AddOutput("Y", "Output of Softshrink operator");
AddAttr<AttrType>("lambda", "non-negative offset")
.SetDefault(static_cast<AttrType>(0.5f));
AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
AddComment(R"DOC(
Softshrink Activation Operator.
......@@ -173,7 +169,6 @@ $$y = x - \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
}
};
template <typename AttrType>
class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
public:
HardShrinkOpMaker(framework::OpProto *proto,
......@@ -181,8 +176,8 @@ class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of HardShrink operator");
AddOutput("Y", "Output of HardShrink operator");
AddAttr<AttrType>("threshold", "The value of threshold for HardShrink")
.SetDefault(static_cast<AttrType>(0.5));
AddAttr<float>("threshold", "The value of threshold for HardShrink")
.SetDefault(0.5f);
AddComment(R"DOC(
HardShrink Activation Operator.
......@@ -308,17 +303,16 @@ $$y = \frac{x}{1 + |x|}$$
}
};
template <typename AttrType>
class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BReluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of BRelu operator");
AddOutput("Y", "Output of BRelu operator");
AddAttr<AttrType>("t_min", "The min marginal value of BRelu")
.SetDefault(static_cast<AttrType>(0));
AddAttr<AttrType>("t_max", "The max marginal value of BRelu")
.SetDefault(static_cast<AttrType>(24));
AddAttr<float>("t_min", "The min marginal value of BRelu")
.SetDefault(static_cast<float>(0));
AddAttr<float>("t_max", "The max marginal value of BRelu")
.SetDefault(static_cast<float>(24));
AddComment(R"DOC(
BRelu Activation Operator.
......@@ -328,7 +322,6 @@ $y = \max(\min(x, t_{min}), t_{max})$
}
};
template <typename AttrType>
class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SoftReluOpMaker(framework::OpProto *proto,
......@@ -336,8 +329,8 @@ class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of SoftRelu operator");
AddOutput("Y", "Output of SoftRelu operator");
AddAttr<AttrType>("threshold", "The threshold value of SoftRelu")
.SetDefault(static_cast<AttrType>(40));
AddAttr<float>("threshold", "The threshold value of SoftRelu")
.SetDefault(40.0f);
AddComment(R"DOC(
SoftRelu Activation Operator.
......@@ -347,15 +340,13 @@ $y = \ln(1 + \exp(\max(\min(x, threshold), threshold))$
}
};
template <typename AttrType>
class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ELUOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of ELU operator");
AddOutput("Y", "Output of ELU operator");
AddAttr<AttrType>("alpha", "The alpha value of ELU")
.SetDefault(static_cast<AttrType>(1.0f));
AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
AddComment(R"DOC(
ELU Activation Operator.
......@@ -368,15 +359,14 @@ $y = \max(0, x) + \min(0, \alpha * (e^x - 1))$
}
};
template <typename AttrType>
class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
public:
Relu6OpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Relu6 operator");
AddOutput("Y", "Output of Relu6 operator");
AddAttr<AttrType>("threshold", "The threshold value of Relu6")
.SetDefault(static_cast<AttrType>(6));
AddAttr<float>("threshold", "The threshold value of Relu6")
.SetDefault(6.0f);
AddComment(R"DOC(
Relu6 Activation Operator.
......@@ -386,15 +376,13 @@ $y = \min(\max(0, x), 6)$
}
};
template <typename AttrType>
class PowOpMaker : public framework::OpProtoAndCheckerMaker {
public:
PowOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Pow operator");
AddOutput("Y", "Output of Pow operator");
AddAttr<AttrType>("factor", "The exponential factor of Pow")
.SetDefault(static_cast<AttrType>(1));
AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
AddComment(R"DOC(
Pow Activation Operator.
......@@ -404,17 +392,16 @@ $y = x^{factor}$
}
};
template <typename AttrType>
class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
public:
STanhOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of STanh operator");
AddOutput("Y", "Output of STanh operator");
AddAttr<AttrType>("scale_a", "The scale parameter of a for the input")
.SetDefault(static_cast<AttrType>(2 / 3));
AddAttr<AttrType>("scale_b", "The scale parameter of b for the input")
.SetDefault(static_cast<AttrType>(1.7159));
AddAttr<float>("scale_a", "The scale parameter of a for the input")
.SetDefault(2.0f / 3.0f);
AddAttr<float>("scale_b", "The scale parameter of b for the input")
.SetDefault(1.7159f);
AddComment(R"DOC(
STanh Activation Operator.
......@@ -424,7 +411,6 @@ $$y = b * \frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$
}
};
template <typename AttrType>
class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ThresholdedReluOpMaker(framework::OpProto *proto,
......@@ -432,8 +418,8 @@ class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of ThresholdedRelu operator");
AddOutput("Y", "Output of ThresholdedRelu operator");
AddAttr<AttrType>("threshold", "The threshold location of activation")
.SetDefault(static_cast<AttrType>(1.0));
AddAttr<float>("threshold", "The threshold location of activation")
.SetDefault(1.0f);
AddComment(R"DOC(
ThresholdedRelu Activation Operator.
......@@ -448,7 +434,6 @@ $$
}
};
template <typename AttrType>
class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
public:
HardSigmoidOpMaker(framework::OpProto *proto,
......@@ -456,10 +441,10 @@ class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of HardSigmoid operator");
AddOutput("Y", "Output of HardSigmoid operator");
AddAttr<AttrType>("slope", "Slope for linear approximation of sigmoid")
.SetDefault(static_cast<AttrType>(0.2));
AddAttr<AttrType>("offset", "Offset for linear approximation of sigmoid")
.SetDefault(static_cast<AttrType>(0.5));
AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
.SetDefault(0.2f);
AddAttr<float>("offset", "Offset for linear approximation of sigmoid")
.SetDefault(0.5f);
AddComment(R"DOC(
HardSigmoid Activation Operator.
......@@ -499,7 +484,7 @@ REGISTER_OP(tanh, ops::ActivationOp, ops::TanhOpMaker, tanh_grad,
REGISTER_OP(tanh_shrink, ops::ActivationOp, ops::TanhShrinkOpMaker,
tanh_shrink_grad, ops::ActivationOpGrad);
REGISTER_OP(softshrink, ops::ActivationOp, ops::SoftShrinkOpMaker<float>,
REGISTER_OP(softshrink, ops::ActivationOp, ops::SoftShrinkOpMaker,
softshrink_grad, ops::ActivationOpGrad);
REGISTER_OP(sqrt, ops::ActivationOp, ops::SqrtOpMaker, sqrt_grad,
......@@ -523,35 +508,34 @@ REGISTER_OP(softplus, ops::ActivationOp, ops::SoftplusOpMaker, softplus_grad,
REGISTER_OP(softsign, ops::ActivationOp, ops::SoftsignOpMaker, softsign_grad,
ops::ActivationOpGrad);
REGISTER_OP(brelu, ops::ActivationOp, ops::BReluOpMaker<float>, brelu_grad,
REGISTER_OP(brelu, ops::ActivationOp, ops::BReluOpMaker, brelu_grad,
ops::ActivationOpGrad);
REGISTER_OP(leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker<float>,
REGISTER_OP(leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker,
leaky_relu_grad, ops::ActivationOpGrad);
REGISTER_OP(soft_relu, ops::ActivationOp, ops::SoftReluOpMaker<float>,
soft_relu_grad, ops::ActivationOpGrad);
REGISTER_OP(soft_relu, ops::ActivationOp, ops::SoftReluOpMaker, soft_relu_grad,
ops::ActivationOpGrad);
REGISTER_OP(elu, ops::ActivationOp, ops::ELUOpMaker<float>, elu_grad,
REGISTER_OP(elu, ops::ActivationOp, ops::ELUOpMaker, elu_grad,
ops::ActivationOpGrad);
REGISTER_OP(relu6, ops::ActivationOp, ops::Relu6OpMaker<float>, relu6_grad,
REGISTER_OP(relu6, ops::ActivationOp, ops::Relu6OpMaker, relu6_grad,
ops::ActivationOpGrad);
REGISTER_OP(pow, ops::ActivationOp, ops::PowOpMaker<float>, pow_grad,
REGISTER_OP(pow, ops::ActivationOp, ops::PowOpMaker, pow_grad,
ops::ActivationOpGrad);
REGISTER_OP(stanh, ops::ActivationOp, ops::STanhOpMaker<float>, stanh_grad,
REGISTER_OP(stanh, ops::ActivationOp, ops::STanhOpMaker, stanh_grad,
ops::ActivationOpGrad);
REGISTER_OP(hard_shrink, ops::ActivationOp, ops::HardShrinkOpMaker<float>,
REGISTER_OP(hard_shrink, ops::ActivationOp, ops::HardShrinkOpMaker,
hard_shrink_grad, ops::ActivationOpGrad);
REGISTER_OP(thresholded_relu, ops::ActivationOp,
ops::ThresholdedReluOpMaker<float>, thresholded_relu_grad,
ops::ActivationOpGrad);
REGISTER_OP(thresholded_relu, ops::ActivationOp, ops::ThresholdedReluOpMaker,
thresholded_relu_grad, ops::ActivationOpGrad);
REGISTER_OP(hard_sigmoid, ops::ActivationOp, ops::HardSigmoidOpMaker<float>,
REGISTER_OP(hard_sigmoid, ops::ActivationOp, ops::HardSigmoidOpMaker,
hard_sigmoid_grad, ops::ActivationOpGrad);
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
......
......@@ -109,4 +109,5 @@ paramOut = param + paramUpdate$$
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(adadelta, ops::AdadeltaOp, ops::AdadeltaOpMaker);
REGISTER_OP_CPU_KERNEL(
adadelta, ops::AdadeltaOpKernel<paddle::platform::CPUPlace, float>);
adadelta, ops::AdadeltaOpKernel<paddle::platform::CPUPlace, float>,
ops::AdadeltaOpKernel<paddle::platform::CPUPlace, double>);
......@@ -17,4 +17,5 @@
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
adadelta, ops::AdadeltaOpKernel<paddle::platform::GPUPlace, float>);
adadelta, ops::AdadeltaOpKernel<paddle::platform::GPUPlace, float>,
ops::AdadeltaOpKernel<paddle::platform::GPUPlace, double>);
......@@ -33,8 +33,8 @@ class AdadeltaOpKernel : public framework::OpKernel<T> {
avg_squared_grad_out_tensor->mutable_data<T>(ctx.GetPlace());
avg_squared_update_out_tensor->mutable_data<T>(ctx.GetPlace());
float rho = ctx.Attr<float>("rho");
float epsilon = ctx.Attr<float>("epsilon");
T rho = static_cast<T>(ctx.Attr<float>("rho"));
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
auto param = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Param"));
......
......@@ -14,8 +14,8 @@
#define EIGEN_USE_GPU
#include "paddle/operators/adagrad_op.h"
#include "paddle/operators/math/selected_rows_functor.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/selected_rows_functor.h"
#include "paddle/platform/cuda_helper.h"
namespace paddle {
......@@ -134,8 +134,8 @@ struct SparseAdagradFunctor<platform::GPUPlace, T> {
T, 256><<<grid2, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(grad_merge_data, grad_merge->rows().data(),
lr, param_data,
moment_data, grad_width, epsilon);
lr, param_data, moment_data, grad_width,
epsilon);
}
};
......
......@@ -127,4 +127,5 @@ paramOut = param - learningRate * moment_1/ ($\sqrt{(moment_2)} + \epsilon)$$
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(adam, ops::AdamOp, ops::AdamOpMaker);
REGISTER_OP_CPU_KERNEL(adam,
ops::AdamOpKernel<paddle::platform::CPUPlace, float>);
ops::AdamOpKernel<paddle::platform::CPUPlace, float>,
ops::AdamOpKernel<paddle::platform::CPUPlace, double>);
......@@ -17,4 +17,5 @@
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(adam,
ops::AdamOpKernel<paddle::platform::GPUPlace, float>);
ops::AdamOpKernel<paddle::platform::GPUPlace, float>,
ops::AdamOpKernel<paddle::platform::GPUPlace, double>);
......@@ -31,9 +31,9 @@ class AdamOpKernel : public framework::OpKernel<T> {
moment1_out_tensor->mutable_data<T>(ctx.GetPlace());
moment2_out_tensor->mutable_data<T>(ctx.GetPlace());
float beta1 = ctx.Attr<float>("beta1");
float beta2 = ctx.Attr<float>("beta2");
float epsilon = ctx.Attr<float>("epsilon");
T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
auto param = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Param"));
......
......@@ -126,4 +126,5 @@ division by 0 error.
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(adamax, ops::AdamaxOp, ops::AdamaxOpMaker);
REGISTER_OP_CPU_KERNEL(adamax,
ops::AdamaxOpKernel<paddle::platform::CPUPlace, float>);
ops::AdamaxOpKernel<paddle::platform::CPUPlace, float>,
ops::AdamaxOpKernel<paddle::platform::CPUPlace, double>);
......@@ -17,4 +17,5 @@
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(adamax,
ops::AdamaxOpKernel<paddle::platform::GPUPlace, float>);
ops::AdamaxOpKernel<paddle::platform::GPUPlace, float>,
ops::AdamaxOpKernel<paddle::platform::GPUPlace, double>);
......@@ -31,9 +31,9 @@ class AdamaxOpKernel : public framework::OpKernel<T> {
moment_out_tensor->mutable_data<T>(ctx.GetPlace());
inf_norm_out_tensor->mutable_data<T>(ctx.GetPlace());
float beta1 = ctx.Attr<float>("beta1");
float beta2 = ctx.Attr<float>("beta2");
float epsilon = ctx.Attr<float>("epsilon");
T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
auto param = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Param"));
......
......@@ -179,7 +179,9 @@ REGISTER_OP(sequence_conv, ops::SequenceConvOp, ops::SequenceConvOpMaker,
sequence_conv_grad, ops::SequenceConvGradOp);
REGISTER_OP_CPU_KERNEL(
sequence_conv, ops::SequenceConvKernel<paddle::platform::CPUPlace, float>);
sequence_conv, ops::SequenceConvKernel<paddle::platform::CPUPlace, float>,
ops::SequenceConvKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
sequence_conv_grad,
ops::SequenceConvGradKernel<paddle::platform::CPUPlace, float>);
ops::SequenceConvGradKernel<paddle::platform::CPUPlace, float>,
ops::SequenceConvGradKernel<paddle::platform::CPUPlace, double>);
......@@ -16,7 +16,9 @@
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
sequence_conv, ops::SequenceConvKernel<paddle::platform::GPUPlace, float>);
sequence_conv, ops::SequenceConvKernel<paddle::platform::GPUPlace, float>,
ops::SequenceConvKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
sequence_conv_grad,
ops::SequenceConvGradKernel<paddle::platform::GPUPlace, float>);
ops::SequenceConvGradKernel<paddle::platform::GPUPlace, float>,
ops::SequenceConvGradKernel<paddle::platform::GPUPlace, double>);
......@@ -248,7 +248,7 @@ def data(name,
stop_gradient=stop_gradient)
def create_tensor(dtype, name=None, main_program=None):
def create_tensor(dtype, name=None, main_program=None, startup_program=None):
helper = LayerHelper("create_tensor", **locals())
return helper.create_variable(name=helper.name, dtype=dtype)
......@@ -412,30 +412,12 @@ _create_op_func_('mul')
_create_op_func_('elementwise_add')
_create_op_func_('dropout')
_create_op_func_('reshape')
_create_op_func_('elementwise_add')
_create_op_func_('sigmoid')
_create_op_func_('scale')
_create_op_func_('reshape')
_create_op_func_('transpose')
def fill_constant(data_type, shape, value=None, program=None):
"""
This function creates a tensor , with shape as mentioned in the input and
specified data_type and fills this up with a constant value that
comes in the input.
"""
helper = LayerHelper('fill_constant', **locals())
out = helper.create_tmp_variable(dtype=data_type)
helper.append_op(
type='fill_constant',
outputs={'Out': [out]},
attrs={'data_type': data_type,
'shape': shape,
'value': value})
return out
def cast(x, data_type, main_program=None):
"""
This function takes in the input with input_data_type
......@@ -478,7 +460,7 @@ def sums(input, main_program=None, startup_program=None):
return out
def assign(input, output, main_program=None):
def assign(input, output, main_program=None, startup_program=None):
helper = LayerHelper('assign', **locals())
helper.append_op(
type='scale',
......@@ -490,7 +472,7 @@ def assign(input, output, main_program=None):
def split_lod_tensor(input,
mask,
level,
level=0,
main_program=None,
startup_program=None):
helper = LayerHelper('split_lod_tensor', **locals())
......@@ -512,11 +494,11 @@ def merge_lod_tensor(in_true,
in_false,
x,
mask,
level,
level=0,
main_program=None,
startup_program=None):
helper = LayerHelper('merge_lod_tensor', **locals())
out = helper.create_tmp_variable(dtype=x.data_type)
out = helper.create_tmp_variable(dtype=in_true.data_type)
helper.append_op(
type='merge_lod_tensor',
inputs={'X': x,
......@@ -1366,7 +1348,7 @@ def array_to_lod_tensor(x, table, main_program=None):
return tmp
def fill_constant(shape, dtype, value, main_program=None):
def fill_constant(shape, dtype, value, main_program=None, startup_program=None):
"""
This function creates a tensor , with shape as mentioned in the input and
specified data_type and fills this up with a constant value that
......@@ -1387,6 +1369,31 @@ def fill_constant(shape, dtype, value, main_program=None):
return out
def fill_constant_batch_size_like(input,
shape,
dtype,
value,
input_dim_idx=0,
output_dim_idx=0,
main_program=None,
startup_program=None):
helper = LayerHelper("fill_constant_batch_size_like", **locals())
out = helper.create_tmp_variable(dtype=dtype)
helper.append_op(
type='fill_constant_batch_size_like',
inputs={'Input': input},
outputs={'Out': [out]},
attrs={
'shape': shape,
'data_type': out.data_type,
'value': float(value),
'input_dim_idx': input_dim_idx,
'output_dim_idx': output_dim_idx
})
out.stop_gradient = True
return out
def ones(shape, dtype, main_program=None):
"""
This function performs the same function as fill_constant() declared above
......@@ -1449,7 +1456,7 @@ def create_array(dtype, main_program=None):
dtype=dtype)
def less_than(x, y, cond=None, main_program=None):
def less_than(x, y, cond=None, main_program=None, **ignored):
helper = LayerHelper("less_than", **locals())
if cond is None:
cond = helper.create_tmp_variable(dtype='bool')
......@@ -1527,13 +1534,20 @@ class ConditionalBlockGuard(BlockGuard):
class ConditionalBlock(object):
def __init__(self, inputs, name=None, main_program=None):
def __init__(self,
inputs,
name=None,
main_program=None,
startup_program=None):
for each_input in inputs:
if not isinstance(each_input, Variable):
raise TypeError("Each input should be variable")
self.inputs = inputs
self.helper = LayerHelper(
'conditional_block', name=name, main_program=main_program)
'conditional_block',
name=name,
main_program=main_program,
startup_program=startup_program)
def block(self):
return ConditionalBlockGuard(self)
......@@ -1578,3 +1592,148 @@ class ConditionalBlock(object):
outputs={'Out': out_list,
'Scope': [step_scope]},
attrs={'block': inside_block})
class IfElseBlockGuard(object):
def __init__(self, is_true, ifelse):
if not isinstance(ifelse, IfElse):
raise TypeError("ifelse must be an instance of IfElse class")
if ifelse.status != IfElse.OUT_IF_ELSE_BLOCKS:
raise ValueError("You cannot invoke IfElse.block() inside a block")
self.is_true = is_true
self.ie = ifelse
if is_true:
self.cond_block = ifelse.conditional_true_block
else:
self.cond_block = ifelse.conditional_false_block
if not isinstance(self.cond_block, ConditionalBlock):
raise TypeError("Unexpected situation")
self.cond_block = self.cond_block.block()
def __enter__(self):
self.ie.status = IfElse.IN_IF_ELSE_TRUE_BLOCKS if self.is_true else IfElse.IN_IF_ELSE_FALSE_BLOCKS
self.cond_block.__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
if not self.cond_block.__exit__(exc_type, exc_val, exc_tb):
# re-raise inside exception
return False
if len(self.ie.output_table[1 if self.is_true else 0]) == 0:
raise ValueError("Must set output inside block")
self.ie.status = IfElse.OUT_IF_ELSE_BLOCKS
class IfElse(object):
OUT_IF_ELSE_BLOCKS = 0
IN_IF_ELSE_TRUE_BLOCKS = 1
IN_IF_ELSE_FALSE_BLOCKS = 2
def __init__(self, cond, name=None, main_program=None,
startup_program=None):
if not isinstance(cond, Variable):
raise TypeError("cond must be a Variable")
self.helper = LayerHelper(
'ifelse',
name=name,
main_program=main_program,
startup_program=startup_program)
self.cond = cond
self.input_table = {}
self.status = IfElse.OUT_IF_ELSE_BLOCKS
self.conditional_true_block = ConditionalBlock(inputs=[self.cond])
self.conditional_false_block = ConditionalBlock(inputs=[self.cond])
self.output_table = ([], []) # (true_outs, false_outs)
def input(self, x):
if self.status == IfElse.OUT_IF_ELSE_BLOCKS:
raise ValueError("input must in true/false blocks")
if id(x) not in self.input_table:
parent_block = self.parent_block()
out_true = parent_block.create_var(
name=unique_name('ifelse_input' + self.helper.name),
dtype=x.data_type)
out_false = parent_block.create_var(
name=unique_name('ifelse_input' + self.helper.name),
dtype=x.data_type)
parent_block.append_op(
type='split_lod_tensor',
inputs={
'X': x,
'Mask': self.cond,
},
outputs={'OutTrue': out_true,
'OutFalse': out_false},
attrs={'level': 0})
self.input_table[id(x)] = (out_true, out_false)
else:
out_true, out_false = self.input_table[id(x)]
if self.status == IfElse.IN_IF_ELSE_TRUE_BLOCKS:
return out_true
else:
return out_false
def parent_block(self):
current_block = self.helper.main_program.current_block()
return self.helper.main_program.block(current_block.parent_idx)
def true_block(self):
return IfElseBlockGuard(True, self)
def false_block(self):
return IfElseBlockGuard(False, self)
def output(self, *outs):
if self.status == self.OUT_IF_ELSE_BLOCKS:
raise ValueError("output can only be invoked in the sub-block")
out_table = self.output_table[1 if self.status ==
self.IN_IF_ELSE_TRUE_BLOCKS else 0]
parent_block = self.parent_block()
for each_out in outs:
if not isinstance(each_out, Variable):
raise TypeError("Each output should be a variable")
# create outside tensor
outside_out = parent_block.create_var(
name=unique_name("_".join([self.helper.name, 'output'])),
dtype=each_out.data_type)
out_table.append(outside_out)
# assign local var to outside
assign(
input=each_out,
output=outside_out,
main_program=self.helper.main_program,
startup_program=self.helper.startup_program)
def __call__(self):
if self.status != self.OUT_IF_ELSE_BLOCKS:
raise ValueError("IfElse::__call__ must be out of sub-block")
false_len, true_len = map(len, self.output_table)
if false_len == 0 and true_len == 0:
raise ValueError("Must invoke true_block/false_block before "
"__call__")
elif false_len != true_len and false_len != 0 and true_len != 0:
raise ValueError("The output side must be same")
elif false_len == 0 or true_len == 0:
return self.output_table[0 if false_len != 0 else 1]
# else none of false_len/true_len is zero
# merge together
rlist = []
for false_var, true_var in zip(*self.output_table):
rlist.append(
merge_lod_tensor(
in_true=true_var,
in_false=false_var,
mask=self.cond,
x=self.cond,
level=0,
main_program=self.helper.main_program,
startup_program=self.helper.startup_program))
return rlist
import paddle.v2.fluid.layers as layers
from paddle.v2.fluid.framework import Program
from paddle.v2.fluid.executor import Executor
from paddle.v2.fluid.optimizer import MomentumOptimizer
import paddle.v2.fluid.core as core
import paddle.v2 as paddle
import unittest
import numpy as np
class TestMNISTIfElseOp(unittest.TestCase):
def test_raw_api(self):
kwargs = {'startup_program': Program(), 'main_program': Program()}
image = layers.data(
name='x', shape=[784], data_type='float32', **kwargs)
label = layers.data(name='y', shape=[1], data_type='int64', **kwargs)
limit = layers.fill_constant_batch_size_like(
input=label, dtype='int64', shape=[1], value=5.0, **kwargs)
cond = layers.less_than(x=label, y=limit, **kwargs)
true_image, false_image = layers.split_lod_tensor(
input=image, mask=cond, **kwargs)
true_out = layers.create_tensor(dtype='float32', **kwargs)
true_cond = layers.ConditionalBlock([true_image], **kwargs)
with true_cond.block():
hidden = layers.fc(input=true_image, size=100, act='tanh', **kwargs)
prob = layers.fc(input=hidden, size=10, act='softmax', **kwargs)
layers.assign(input=prob, output=true_out, **kwargs)
false_out = layers.create_tensor(dtype='float32', **kwargs)
false_cond = layers.ConditionalBlock([false_image], **kwargs)
with false_cond.block():
hidden = layers.fc(input=false_image,
size=200,
act='tanh',
**kwargs)
prob = layers.fc(input=hidden, size=10, act='softmax', **kwargs)
layers.assign(input=prob, output=false_out, **kwargs)
prob = layers.merge_lod_tensor(
in_true=true_out, in_false=false_out, mask=cond, x=image, **kwargs)
loss = layers.cross_entropy(input=prob, label=label, **kwargs)
avg_loss = layers.mean(x=loss, **kwargs)
optimizer = MomentumOptimizer(learning_rate=0.001, momentum=0.9)
optimizer.minimize(avg_loss, kwargs['startup_program'])
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=8192),
batch_size=200)
place = core.CPUPlace()
exe = Executor(place)
exe.run(kwargs['startup_program'])
PASS_NUM = 100
for pass_id in range(PASS_NUM):
for data in train_reader():
x_data = np.array(map(lambda x: x[0], data)).astype("float32")
y_data = np.array(map(lambda x: x[1], data)).astype("int64")
y_data = np.expand_dims(y_data, axis=1)
tensor_x = core.LoDTensor()
tensor_x.set(x_data, place)
tensor_y = core.LoDTensor()
tensor_y.set(y_data, place)
outs = map(np.array,
exe.run(kwargs['main_program'],
feed={'x': tensor_x,
'y': tensor_y},
fetch_list=[avg_loss]))
print outs[0]
if outs[0] < 1.0:
return
self.assertFalse(True)
def test_ifelse(self):
kwargs = {'startup_program': Program(), 'main_program': Program()}
image = layers.data(
name='x', shape=[784], data_type='float32', **kwargs)
label = layers.data(name='y', shape=[1], data_type='int64', **kwargs)
limit = layers.fill_constant_batch_size_like(
input=label, dtype='int64', shape=[1], value=5.0, **kwargs)
cond = layers.less_than(x=label, y=limit, **kwargs)
ie = layers.IfElse(cond, **kwargs)
with ie.true_block():
true_image = ie.input(image)
hidden = layers.fc(input=true_image, size=100, act='tanh', **kwargs)
prob = layers.fc(input=hidden, size=10, act='softmax', **kwargs)
ie.output(prob)
with ie.false_block():
false_image = ie.input(image)
hidden = layers.fc(input=false_image,
size=200,
act='tanh',
**kwargs)
prob = layers.fc(input=hidden, size=10, act='softmax', **kwargs)
ie.output(prob)
prob = ie()
loss = layers.cross_entropy(input=prob[0], label=label, **kwargs)
avg_loss = layers.mean(x=loss, **kwargs)
optimizer = MomentumOptimizer(learning_rate=0.001, momentum=0.9)
optimizer.minimize(avg_loss, kwargs['startup_program'])
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=8192),
batch_size=200)
place = core.CPUPlace()
exe = Executor(place)
exe.run(kwargs['startup_program'])
PASS_NUM = 100
for pass_id in range(PASS_NUM):
for data in train_reader():
x_data = np.array(map(lambda x: x[0], data)).astype("float32")
y_data = np.array(map(lambda x: x[1], data)).astype("int64")
y_data = np.expand_dims(y_data, axis=1)
tensor_x = core.LoDTensor()
tensor_x.set(x_data, place)
tensor_y = core.LoDTensor()
tensor_y.set(y_data, place)
outs = map(np.array,
exe.run(kwargs['main_program'],
feed={'x': tensor_x,
'y': tensor_y},
fetch_list=[avg_loss]))
print outs[0]
if outs[0] < 1.0:
return
self.assertFalse(True)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册