diff --git a/CMakeLists.txt b/CMakeLists.txt index ae8728f4d4c22f45f13a283a448e907337f37f7a..65164b8472b902be8b0b9d5fb99807d012b8a666 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -109,7 +109,7 @@ else() endif() set(WITH_MKLML ${WITH_MKL}) -if (WITH_MKL AND ${AVX2_FOUND}) +if (WITH_MKL AND AVX2_FOUND) set(WITH_MKLDNN ON) else() message(STATUS "Do not have AVX2 intrinsics and disabled MKL-DNN") diff --git a/doc/api/v2/config/layer.rst b/doc/api/v2/config/layer.rst index b2b55ec419d2f8453e067f202f6c1b7da6c201de..d4d182f6692e09b3e40f3620b77d9a0f20ec5af3 100644 --- a/doc/api/v2/config/layer.rst +++ b/doc/api/v2/config/layer.rst @@ -382,6 +382,11 @@ cos_sim .. autoclass:: paddle.v2.layer.cos_sim :noindex: +l2_distance +----------- +.. autoclass:: paddle.v2.layer.l2_distance + :noindex: + trans ----- .. autoclass:: paddle.v2.layer.trans diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 00d9dd238ec5328be28f58f8118daad3a039e08c..b9018ecdba8303fd6b37c87edd99e192aa604228 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -513,19 +513,14 @@ ParamGradInfoMap AppendBackward( const int root_block_idx = 0; auto root_block = program_desc.MutableBlock(root_block_idx); - // insert fill one op for target - // TODO(qiao) add some check to the target. std::string fill_one_op_out = GradVarName(target.Name()); - std::vector target_shape_desc = target.Shape(); - std::vector target_shape; - std::transform(target_shape_desc.begin(), target_shape_desc.end(), - std::back_inserter(target_shape), - [](int64_t dim) { return static_cast(dim); }); + bool is_scalar = target.Shape() == std::vector{1}; + PADDLE_ENFORCE(is_scalar, "target should be scalar"); VLOG(3) << "backward from loss=" << target.Name() << " data_type=" << target.GetDataType(); std::unique_ptr fill_one_op( new OpDescBind("fill_constant", {}, {{"Out", {fill_one_op_out}}}, - {{"shape", target_shape}, + {{"shape", std::vector{1}}, {"value", static_cast(1.0)}, {"data_type", target.GetDataType()}})); // infer var type of fill_one_op diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index d485cdf6109274377ad0057223bdd8401e964aa7..2b858f5ea0874d7bf1a9cf38529f5d0d70cca7f2 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -508,6 +508,7 @@ TEST(Backward, simple_single_op) { op->SetOutput("Out", {"out"}); auto target = f::VarDescBind("out"); + target.SetShape({1}); auto var_to_grad = AppendBackward(program, target, {}); ASSERT_EQ(block->AllOps().size(), 3UL); @@ -544,6 +545,7 @@ TEST(Backward, default_attribute) { op->CheckAttrs(); auto target = f::VarDescBind("out"); + target.SetShape({1}); AppendBackward(program, target, {}); ASSERT_EQ(block->AllOps().size(), 3UL); @@ -581,6 +583,7 @@ TEST(Backward, simple_mult_op) { op3->SetOutput("Out", {"out3"}); auto target = f::VarDescBind("out3"); + target.SetShape({1}); size_t forward_len = block->AllOps().size(); auto var_to_grad = AppendBackward(program, target, {}); @@ -670,6 +673,7 @@ TEST(Backward, intermedia_var_no_grad) { op4->SetOutput("Out", {"out4"}); auto target = f::VarDescBind("out4"); + target.SetShape({1}); size_t forward_len = block->AllOps().size(); auto var_to_grad = AppendBackward(program, target, {"out3"}); @@ -730,6 +734,7 @@ TEST(Backward, var_no_grad) { op2->SetOutput("Z", {"z2"}); auto target = f::VarDescBind("z2"); + target.SetShape({1}); size_t forward_len = block->AllOps().size(); auto var_to_grad = AppendBackward(program, target, {"z1"}); @@ -810,6 +815,7 @@ TEST(Backward, shared_var) { op3->SetOutput("Out", {"out3"}); auto target = f::VarDescBind("out3"); + target.SetShape({1}); size_t forward_len = block->AllOps().size(); auto var_to_grad = AppendBackward(program, target, {}); @@ -888,6 +894,7 @@ TEST(Backward, half_backward) { op1->SetOutput("Out", {"out"}); auto target = f::VarDescBind("out"); + target.SetShape({1}); size_t forward_len = block->AllOps().size(); auto var_to_grad = AppendBackward(program, target, {"b"}); f::OpDescBind *fill_op = block->AllOps()[forward_len]; diff --git a/paddle/framework/data_type.h b/paddle/framework/data_type.h index be144d8fc0104fccc08006532a85906ade25c2a1..c54d2d4ddf09c445fb25c1fbe8a7498f233d8212 100644 --- a/paddle/framework/data_type.h +++ b/paddle/framework/data_type.h @@ -46,6 +46,8 @@ inline std::type_index ToTypeIndex(DataType type) { return typeid(int); case DataType::INT64: return typeid(int64_t); + case DataType::BOOL: + return typeid(bool); default: PADDLE_THROW("Not support type %d", type); } @@ -66,6 +68,9 @@ inline void VisitDataType(DataType type, Visitor visitor) { case DataType::INT64: visitor.template operator()(); break; + case DataType::BOOL: + visitor.template operator()(); + break; default: PADDLE_THROW("Not supported"); } diff --git a/paddle/gserver/layers/L2DistanceLayer.cpp b/paddle/gserver/layers/L2DistanceLayer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c71df1b92cef9b19001a0984953a260fbdd1d762 --- /dev/null +++ b/paddle/gserver/layers/L2DistanceLayer.cpp @@ -0,0 +1,91 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "L2DistanceLayer.h" +#include "paddle/utils/Logging.h" +#include "paddle/utils/Stat.h" + +namespace paddle { + +REGISTER_LAYER(l2_distance, L2DistanceLayer); + +bool L2DistanceLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + /* Initialize the basic parent class */ + Layer::init(layerMap, parameterMap); + + CHECK_EQ(inputLayers_.size(), 2UL) << "The L2DistanceLayer accepts two and " + << "only two inputs."; + CHECK_EQ(getSize(), 1UL) << "The output dimensionality of L2DistanceLayer " + << "is fixed to be 1."; + + return true; +} + +void L2DistanceLayer::forward(PassType passType) { + Layer::forward(passType); + + const auto inV1 = getInputValue(0); + const auto inV2 = getInputValue(1); + + CHECK(inV1 && inV2); + CHECK_EQ(inV1->getHeight(), inV2->getHeight()) + << "The height of two inputs of this layer must be the same."; + CHECK_EQ(inV1->getWidth(), inV2->getWidth()) + << "The width of two inputs of this layer must be the same."; + + int batchSize = inV1->getHeight(); + int output_dim = getSize(); + { + REGISTER_TIMER_INFO("L2DistanceBpAtvTimer", getName().c_str()); + reserveOutput(batchSize, output_dim); + auto outV = getOutputValue(); + CHECK(outV) << "The output matrix should not be null."; + + Matrix::resizeOrCreate( + inputSub_, inV1->getHeight(), inV1->getWidth(), false, useGpu_); + + inputSub_->assign(*inV1); + inputSub_->sub(*inV2); + outV->sumOfProducts(*inputSub_, *inputSub_, 1, 0); + outV->sqrt2(*outV); + } +} + +void L2DistanceLayer::backward(const UpdateCallback& callback) { + const auto outG = getOutputGrad(); + const auto outV = getOutputValue(); + CHECK(outG && outV); + + auto inGrad1 = getInputGrad(0); + auto inGrad2 = getInputGrad(1); + + { + REGISTER_TIMER_INFO("L2DistanceBpAtvTimer", getName().c_str()); + + if (inGrad1 || inGrad2) { + outV->scalarDiv(*outV, 1.); + outV->dotMul(*outG, *outV); + } + + if (inGrad1) inGrad1->addRowScale(0, *inputSub_, *outV); + + if (inGrad2) { + inputSub_->mulScalar(-1.); + inGrad2->addRowScale(0, *inputSub_, *outV); + } + } +} + +} // namespace paddle diff --git a/paddle/gserver/layers/L2DistanceLayer.h b/paddle/gserver/layers/L2DistanceLayer.h new file mode 100644 index 0000000000000000000000000000000000000000..9b12847a10e64a713635c0df079507b23a73c257 --- /dev/null +++ b/paddle/gserver/layers/L2DistanceLayer.h @@ -0,0 +1,52 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "Layer.h" +#include "paddle/math/Matrix.h" + +namespace paddle { + +/** + * @brief The layer calculates the l2 distance between two input vectors. + * \f[ + * f(\bf{x}, \bf{y}) = \sqrt{\sum_{i=1}^D(x_i - y_i)} + * \f] + * + * - Input1: A vector (batchSize * dataDim) + * - Input2: A vector (batchSize * dataDim) + * - Output: A vector (batchSize * 1) + * + * The configuration api is: l2_distance_layer. + */ + +class L2DistanceLayer : public Layer { +public: + explicit L2DistanceLayer(const LayerConfig& config) : Layer(config) {} + ~L2DistanceLayer() {} + + bool init(const LayerMap& layerMap, + const ParameterMap& parameterMap) override; + + void forward(PassType passType) override; + void backward(const UpdateCallback& callback = nullptr) override; + +private: + // Store the result of subtracting Input2 from Input1 in forward computation, + // which will be reused in backward computation. + MatrixPtr inputSub_; +}; + +} // namespace paddle diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index fb4eea6f67da9078ef43268a3a1603dc6ccfa652..cacf10692942f5eca2f6c498183f4acc00768460 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -583,6 +583,7 @@ TEST(Layer, maxoutLayer) { testLayerGrad(config, "maxout", 10, false, useGpu); } } + void testFcLayer(string format, size_t nnz) { TestConfig config; config.biasSize = 1024; @@ -2444,6 +2445,25 @@ TEST(Layer, ScaleSubRegionLayer) { } } +TEST(Layer, L2DistanceLayer) { + TestConfig config; + config.layerConfig.set_type("l2_distance"); + config.layerConfig.set_size(1); + config.biasSize = 0; + + const size_t input_dim = 27; + const size_t batch_size = 11; + + config.inputDefs.push_back({INPUT_DATA, "layer_0", input_dim, 0}); + config.inputDefs.push_back({INPUT_DATA, "layer_1", input_dim, 0}); + config.layerConfig.add_inputs(); + config.layerConfig.add_inputs(); + + for (auto useGpu : {false, true}) { + testLayerGrad(config, "l2_distance", batch_size, false, useGpu); + } +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); initMain(argc, argv); diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 46c2833030c936119e98adcdd338245bbdaddce7..d0fe5b4635174fa0f74658509c4e8ca58a1d056a 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -87,6 +87,11 @@ function(op_library TARGET) file(APPEND ${pybind_file} "USE_OP(pool2d_cudnn);\n") endif() + if ("${TARGET}" STREQUAL "logical_op") + set(pybind_flag 1) + file(APPEND ${pybind_file} "USE_OP(logical_and);\n") + endif() + # pool_with_index_op contains several operators if ("${TARGET}" STREQUAL "pool_with_index_op") set(pybind_flag 1) diff --git a/paddle/operators/logical_op.cc b/paddle/operators/logical_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..a37582c1d840ac11f847d8743c824ef1aef0fd66 --- /dev/null +++ b/paddle/operators/logical_op.cc @@ -0,0 +1,153 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/logical_op.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { +template +class BinaryLogicalOpProtoMaker : public framework::OpProtoAndCheckerMaker { + public: + BinaryLogicalOpProtoMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + OpComment comment; + AddInput("X", + string::Sprintf("(LoDTensor) Left hand operand of %s operator", + comment.type)); + AddInput("Y", + string::Sprintf("(LoDTensor) Right hand operand of %s operator", + comment.type)); + AddOutput("Out", string::Sprintf( + "(LoDTensor) n-dim bool tensor. Each element is %s", + comment.equation)); + AddComment(string::Sprintf(R"DOC(%s Operator + +It operates element-wise on X and Y, and returns the Out. X, Y and Out are N-dim boolean tensors. +Each element of Out is calculated by %s +)DOC", + comment.type, comment.equation)); + } +}; + +template +class UnaryLogicalOpProtoMaker : public framework::OpProtoAndCheckerMaker { + public: + UnaryLogicalOpProtoMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + OpComment comment; + AddInput("X", string::Sprintf("(LoDTensor) Operand of %s operator", + comment.type)); + AddOutput("Out", string::Sprintf( + "(LoDTensor) n-dim bool tensor. Each element is %s", + comment.equation)); + AddComment(string::Sprintf(R"DOC(%s Operator + +It operates element-wise on X, and returns the Out. X and Out are N-dim boolean tensors. +Each element of Out is calculated by %s +)DOC", + comment.type, comment.equation)); + } +}; + +template +class BinaryLogicalOpInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *context) const override { + OpComment comment; + PADDLE_ENFORCE(context->HasInput("X"), + "Input(X) of %s operator must not be null", comment.type); + PADDLE_ENFORCE(context->HasInput("Y"), + "Input(Y) of %s operator must not be null", comment.type); + auto dim_x = context->GetInputDim("X"); + auto dim_y = context->GetInputDim("Y"); + PADDLE_ENFORCE_EQ(framework::product(dim_x), framework::product(dim_y), + "The number of elements in X and Y should be same"); + + context->SetOutputDim("Out", context->GetInputDim("X")); + context->ShareLoD("X", "Out"); + } +}; + +template +class UnaryLogicalOpInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *context) const override { + OpComment comment; + PADDLE_ENFORCE(context->HasInput("X"), + "Input(X) of %s operator must not be null", comment.type); + auto dim_x = context->GetInputDim("X"); + + context->SetOutputDim("Out", context->GetInputDim("X")); + context->ShareLoD("X", "Out"); + } +}; + +class LogicalOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext &ctx) const override { + framework::OpKernelType kt = OperatorWithKernel::GetKernelType(ctx); + // LogicalOp kernel's device type is decided by input tensor place + kt.place_ = ctx.Input("X")->place(); + return kt; + } +}; + +} // namespace operators +} // namespace paddle + +#define REGISTER_BINARY_LOGICAL_OP(op_type, _equation) \ + struct _##op_type##Comment { \ + static char type[]; \ + static char equation[]; \ + }; \ + char _##op_type##Comment::type[]{#op_type}; \ + char _##op_type##Comment::equation[]{_equation}; \ + REGISTER_OPERATOR( \ + op_type, ::paddle::operators::LogicalOp, \ + ::paddle::operators::BinaryLogicalOpProtoMaker<_##op_type##Comment>, \ + ::paddle::operators::BinaryLogicalOpInferShape<_##op_type##Comment>, \ + ::paddle::framework::EmptyGradOpMaker); + +#define REGISTER_UNARY_LOGICAL_OP(op_type, _equation) \ + struct _##op_type##Comment { \ + static char type[]; \ + static char equation[]; \ + }; \ + char _##op_type##Comment::type[]{#op_type}; \ + char _##op_type##Comment::equation[]{_equation}; \ + REGISTER_OPERATOR( \ + op_type, ::paddle::operators::LogicalOp, \ + ::paddle::operators::UnaryLogicalOpProtoMaker<_##op_type##Comment>, \ + ::paddle::operators::UnaryLogicalOpInferShape<_##op_type##Comment>, \ + ::paddle::framework::EmptyGradOpMaker); + +REGISTER_BINARY_LOGICAL_OP(logical_and, "Out = X && Y"); +REGISTER_BINARY_LOGICAL_KERNEL(logical_and, CPU, + paddle::operators::LogicalAndFunctor); +REGISTER_BINARY_LOGICAL_OP(logical_or, "Out = X && Y"); +REGISTER_BINARY_LOGICAL_KERNEL(logical_or, CPU, + paddle::operators::LogicalOrFunctor); +REGISTER_UNARY_LOGICAL_OP(logical_not, "Out = !X"); +REGISTER_UNARY_LOGICAL_KERNEL(logical_not, CPU, + paddle::operators::LogicalNotFunctor); +REGISTER_BINARY_LOGICAL_OP(logical_xor, "Out = (X || Y) && !(X && Y)"); +REGISTER_BINARY_LOGICAL_KERNEL(logical_xor, CPU, + paddle::operators::LogicalXorFunctor); diff --git a/paddle/operators/logical_op.cu b/paddle/operators/logical_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..d41239b2ca43e7145ea56afcb0af69948838cc48 --- /dev/null +++ b/paddle/operators/logical_op.cu @@ -0,0 +1,24 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/logical_op.h" + +REGISTER_BINARY_LOGICAL_KERNEL(logical_and, GPU, + paddle::operators::LogicalAndFunctor); +REGISTER_BINARY_LOGICAL_KERNEL(logical_or, GPU, + paddle::operators::LogicalOrFunctor); +REGISTER_UNARY_LOGICAL_KERNEL(logical_not, GPU, + paddle::operators::LogicalNotFunctor); +REGISTER_BINARY_LOGICAL_KERNEL(logical_xor, GPU, + paddle::operators::LogicalXorFunctor); diff --git a/paddle/operators/logical_op.h b/paddle/operators/logical_op.h new file mode 100644 index 0000000000000000000000000000000000000000..6e78a7d6ed87ba950886e6bc667f82118ff78904 --- /dev/null +++ b/paddle/operators/logical_op.h @@ -0,0 +1,93 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include +#include +#include "paddle/framework/op_registry.h" +#include "paddle/platform/transform.h" + +namespace paddle { +namespace operators { + +template +struct LogicalAndFunctor { + using ELEM_TYPE = T; + HOSTDEVICE bool operator()(const T& a, const T& b) const { return a && b; } +}; + +template +struct LogicalOrFunctor { + using ELEM_TYPE = T; + HOSTDEVICE bool operator()(const T& a, const T& b) const { return a || b; } +}; + +template +struct LogicalNotFunctor { + using ELEM_TYPE = T; + HOSTDEVICE bool operator()(const T& a) const { return !a; } +}; + +template +struct LogicalXorFunctor { + using ELEM_TYPE = T; + HOSTDEVICE bool operator()(const T& a, const T& b) const { + return (a || b) && !(a && b); + } +}; + +template +class BinaryLogicalOpKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + using T = typename Functor::ELEM_TYPE; + auto* x = context.Input("X"); + auto* y = context.Input("Y"); + auto* out = context.Output("Out"); + Functor binary_func; + platform::Transform trans; + trans(context.device_context(), x->data(), x->data() + x->numel(), + y->data(), out->mutable_data(context.GetPlace()), + binary_func); + } +}; + +template +class UnaryLogicalOpKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + using T = typename Functor::ELEM_TYPE; + auto* x = context.Input("X"); + auto* out = context.Output("Out"); + Functor unary_func; + platform::Transform trans; + trans(context.device_context(), x->data(), x->data() + x->numel(), + out->mutable_data(context.GetPlace()), unary_func); + } +}; + +} // namespace operators +} // namespace paddle + +#define REGISTER_BINARY_LOGICAL_KERNEL(op_type, dev, functor) \ + REGISTER_OP_##dev##_KERNEL( \ + op_type, ::paddle::operators::BinaryLogicalOpKernel< \ + ::paddle::platform::dev##Place, functor>); + +#define REGISTER_UNARY_LOGICAL_KERNEL(op_type, dev, functor) \ + REGISTER_OP_##dev##_KERNEL( \ + op_type, ::paddle::operators::UnaryLogicalOpKernel< \ + ::paddle::platform::dev##Place, functor>); diff --git a/paddle/operators/sequence_slice_op.h b/paddle/operators/sequence_slice_op.h index 2ef2c8f0c4f3061c0082916edc18564698821bfc..dc3ee52ecdb1003496b427bf80dc26316e7a911a 100755 --- a/paddle/operators/sequence_slice_op.h +++ b/paddle/operators/sequence_slice_op.h @@ -77,13 +77,13 @@ class SequenceSliceOpKernel : public framework::OpKernel { for (size_t i = 0; i < n; ++i) { PADDLE_ENFORCE_LT(0, offset_data[i], - "The offset must greater than zero") + "The offset[%d] must greater than zero.") PADDLE_ENFORCE_LT(0, length_data[i], - "The length must greater than zero") + "The length[%d] must greater than zero.") PADDLE_ENFORCE_LT( lod[0][i] + offset_data[i] + length_data[i], lod[0][i + 1], - "The target tensor's length overflow") + "The target tensor's length overflow.") } out->mutable_data(ctx.GetPlace()); diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index d6128dd7692a2faebf453d239744c4893d84e369..0b523ac7e0bf5231398778ea69270c883ac112d2 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -3342,6 +3342,20 @@ class RowL2NormLayer(LayerBase): self.set_layer_size(input_layer.size) +@config_layer('cos') +class CosSimLayer(LayerBase): + def __init__(self, name, inputs, cos_scale=1, device=None): + super(CosSimLayer, self).__init__( + name, 'cos', 1, inputs=inputs, device=device) + config_assert( + len(self.inputs) == 2, + 'The CosSimLayer expects two and only two inputs.') + config_assert( + self.get_input_layer(0).size == self.get_input_layer(1).size, + 'The two inputs of CosSimLayer must have the same dimensionality.') + self.config.cos_scale = cos_scale + + @config_layer('cos_vm') class CosSimVecMatLayer(LayerBase): def __init__(self, name, size, inputs, cos_scale=1.0, device=None): @@ -3349,10 +3363,24 @@ class CosSimVecMatLayer(LayerBase): name, 'cos_vm', size, inputs=inputs, device=device) self.config.cos_scale = cos_scale config_assert( - len(self.inputs) == 2, 'CosSimVecMatLayer must have 2 inputs') + len(self.inputs) == 2, 'The CosSimVecMatLayer must have 2 inputs.') config_assert( size * self.get_input_layer(0).size == self.get_input_layer(1).size, - 'Wrong input size for CosSimVecMatLayer') + 'Wrong input size for CosSimVecMatLayer.') + + +@config_layer('l2_distance') +class L2DistanceLayer(LayerBase): + def __init__(self, name, inputs, device=None): + super(L2DistanceLayer, self).__init__( + name, 'l2_distance', 1, inputs=inputs, device=device) + config_assert( + len(self.inputs) == 2, ('The L2DistanceLayer must have ' + 'and only have 2 inputs.')) + config_assert( + self.get_input_layer(0).size == self.get_input_layer(1).size, + ('Two inputs of the L2DistanceLayer must have ' + 'the same dimensionality.')) @config_layer('sampling_id') @@ -3396,18 +3424,6 @@ class AverageLayer(LayerBase): self.create_bias_parameter(bias, self.config.size) -@config_layer('cos') -class CosSimLayer(LayerBase): - def __init__(self, name, inputs, cos_scale=1, device=None): - super(CosSimLayer, self).__init__( - name, 'cos', 1, inputs=inputs, device=device) - config_assert(len(self.inputs) == 2, 'CosSimLayer must have 2 inputs') - config_assert( - self.get_input_layer(0).size == self.get_input_layer(1).size, - 'inputs of CosSimLayer must have same dim') - self.config.cos_scale = cos_scale - - @config_layer('tensor') class TensorLayer(LayerBase): def __init__(self, name, size, inputs, bias=True, **xargs): diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 388535d53a9d1d6747ac89cb698f3a1f496b5f7c..14cdee4c5564f7a1cc4ff7a19f4f7ac02b08f21c 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -51,6 +51,7 @@ __all__ = [ 'last_seq', 'first_seq', 'cos_sim', + 'l2_distance_layer', 'hsigmoid', 'conv_projection', 'square_error_cost', @@ -168,6 +169,7 @@ class LayerType(object): COST = 'cost' COSINE_SIM_VEC = 'cos_vm' COSINE_SIM = 'cos' + L2_DISTANCE = 'l2_distance' HSIGMOID = 'hsigmoid' CONV_LAYER = 'conv' CONVTRANS_LAYER = 'convt' @@ -2334,6 +2336,51 @@ def cos_sim(a, b, scale=1, size=1, name=None, layer_attr=None): return LayerOutput(name, LayerType.COSINE_SIM, parents=[a, b], size=size) +@wrap_name_default() +@layer_support() +def l2_distance_layer(x, y, name=None, layer_attr=None): + """ + This layer calculates and returns the Euclidean distance between two input + vectors x and y. The equation is as follows: + + .. math:: + l2_distance(\\mathbf{x}, \\mathbf{y}) = \\sqrt{\\sum_{i=1}^D(x_i - y_i)} + + The output size of this layer is fixed to be 1. Note that the above + computation is for one sample. Multiple samples are processed in one batch. + + The example usage is: + + .. code-block:: python + + l2_sim = l2_distance(x=layer1, y=layer2) + + :param name: The name of this layer. It is optional. + :type name: basestring + :param x: The first input x for this layer, whose output is a matrix with + dimensionality N x D. N is the sample number in a mini-batch. + D is the dimensionality of x's output. + :type x: LayerOutput + :param y: The second input y for this layer, whose output is a matrix with + dimensionality N x D. N is the sample number in a mini-batch. + D is the dimensionality of y's output. + :type y: LayerOutput + :param layer_attr: The extra layer attributes, for example, drop rate. + See ExtraLayerAttribute for more details. + :type layer_attr: ExtraLayerAttribute + :return: The returned LayerOutput object. + :rtype: LayerOutput + """ + + assert isinstance(x, LayerOutput) and isinstance(y, LayerOutput) + Layer( + name=name, + type=LayerType.L2_DISTANCE, + inputs=[x.name, y.name], + **ExtraLayerAttribute.to_kwargs(layer_attr)) + return LayerOutput(name, LayerType.L2_DISTANCE, parents=[x, y], size=1) + + @wrap_name_default() @wrap_bias_attr_default(has_bias=True) @wrap_param_attr_default() @@ -3873,7 +3920,7 @@ def recurrent_layer(input, :type input: LayerOutput :param act: Activation type. TanhActivation is the default activation. :type act: BaseActivation - :param bias_attr: The parameter attribute for bias. If this parameter is set to + :param bias_attr: The parameter attribute for bias. If this parameter is set to False or an object whose type is not ParameterAttribute, no bias is defined. If the parameter is set to True, the bias is initialized to zero. diff --git a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh index 0b269a1ff76530774b4d23b0867350fd95e081a3..6c09ca3d346e36b6f38ac5c89914e95a7383f719 100755 --- a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh +++ b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh @@ -10,7 +10,8 @@ test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_la test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer test_kmax_seq_socre_layer test_sub_nested_seq_select_layer test_scale_shift_layer test_seq_slice_layer test_cross_entropy_over_beam test_roi_pool_layer test_pooling3D_layer -test_conv3d_layer test_deconv3d_layer test_BatchNorm3D test_resize_layer test_scale_sub_region_layer -test_dot_prod_layer) +test_conv3d_layer test_deconv3d_layer test_BatchNorm3D test_resize_layer +test_conv3d_layer test_deconv3d_layer test_BatchNorm3D test_resize_layer +test_scale_sub_region_layer test_dot_prod_layer test_l2_distance_layer) export whole_configs=(test_split_datasource) diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_l2_distance_layer.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_l2_distance_layer.protostr new file mode 100644 index 0000000000000000000000000000000000000000..9ba33689edc893c2169a73679a04a6f51cfc83a8 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_l2_distance_layer.protostr @@ -0,0 +1,39 @@ +type: "nn" +layers { + name: "x" + type: "data" + size: 128 + active_type: "" +} +layers { + name: "y" + type: "data" + size: 128 + active_type: "" +} +layers { + name: "__l2_distance_layer_0__" + type: "l2_distance" + size: 1 + active_type: "" + inputs { + input_layer_name: "x" + } + inputs { + input_layer_name: "y" + } +} +input_layer_names: "x" +input_layer_names: "y" +output_layer_names: "__l2_distance_layer_0__" +sub_models { + name: "root" + layer_names: "x" + layer_names: "y" + layer_names: "__l2_distance_layer_0__" + input_layer_names: "x" + input_layer_names: "y" + output_layer_names: "__l2_distance_layer_0__" + is_recurrent_layer_group: false +} + diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_l2_distance_layer.py b/python/paddle/trainer_config_helpers/tests/configs/test_l2_distance_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..b36a5c6d1222860ee4b77f89ad4b6148ccd89589 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/test_l2_distance_layer.py @@ -0,0 +1,7 @@ +from paddle.trainer_config_helpers import * + +outputs( + l2_distance_layer( + x=data_layer( + name='x', size=128), y=data_layer( + name='y', size=128))) diff --git a/python/paddle/v2/fluid/tests/test_logical_op.py b/python/paddle/v2/fluid/tests/test_logical_op.py new file mode 100644 index 0000000000000000000000000000000000000000..ac90bf839cb96053387bb82c112692136707744c --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_logical_op.py @@ -0,0 +1,35 @@ +import op_test +import unittest +import numpy as np + + +def create_test_class(op_type, callback, binary_op=True): + class Cls(op_test.OpTest): + def setUp(self): + a = np.random.choice(a=[True, False], size=(10, 7)).astype(bool) + if binary_op: + b = np.random.choice(a=[True, False], size=(10, 7)).astype(bool) + c = callback(a, b) + else: + c = callback(a) + self.outputs = {'Out': c} + self.op_type = op_type + if binary_op: + self.inputs = {'X': a, 'Y': b} + else: + self.inputs = {'X': a} + + def test_output(self): + self.check_output() + + Cls.__name__ = op_type + globals()[op_type] = Cls + + +create_test_class('logical_and', lambda _a, _b: np.logical_and(_a, _b)) +create_test_class('logical_or', lambda _a, _b: np.logical_or(_a, _b)) +create_test_class('logical_not', lambda _a: np.logical_not(_a), False) +create_test_class('logical_xor', lambda _a, _b: np.logical_xor(_a, _b)) + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/fluid/tests/test_optimizer.py b/python/paddle/v2/fluid/tests/test_optimizer.py index 7b4237e7fdf5990019ddd85967036ceb598c33df..2459dfd664300d405edb36c4ca906c1769b5e7d2 100644 --- a/python/paddle/v2/fluid/tests/test_optimizer.py +++ b/python/paddle/v2/fluid/tests/test_optimizer.py @@ -16,14 +16,18 @@ class TestOptimizer(unittest.TestCase): dtype="float32", shape=[10, 8], lod_level=0, name="mul.y") mul_out = block.create_var( dtype="float32", shape=[5, 8], lod_level=0, name="mul.out") + mean_out = block.create_var( + dtype="float32", shape=[1], lod_level=0, name="mean.out") block.append_op( type="mul", inputs={"X": mul_x, "Y": mul_y}, outputs={"Out": mul_out}, attrs={"x_num_col_dims": 1}) + block.append_op( + type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out}) sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.01) - opts = sgd_optimizer.minimize(mul_out, init_program) + opts = sgd_optimizer.minimize(mean_out, init_program) self.assertEqual(len(opts), 1) sgd_op = opts[0] self.assertEqual(sgd_op.type, "sgd") @@ -44,12 +48,16 @@ class TestOptimizer(unittest.TestCase): "Y": mul_y}, outputs={"Out": mul_out}, attrs={"x_num_col_dims": 1}) + mean_out = block.create_var( + dtype="float32", shape=[1], lod_level=0, name="mean.out") + block.append_op( + type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out}) global_step = block.create_var( dtype="float32", shape=[1], lod_level=0, name="step") learning_rate = 0.01 sgd_optimizer = optimizer.SGDOptimizer( learning_rate=learning_rate, global_step=global_step) - opts = sgd_optimizer.minimize(mul_out, init_program) + opts = sgd_optimizer.minimize(mean_out, init_program) self.assertEqual(len(opts), 2) sgd_op = opts[0] self.assertEqual(sgd_op.type, "sgd") @@ -90,7 +98,11 @@ class TestMomentumOptimizer(unittest.TestCase): learning_rate = 0.01 momentum_optimizer = self.MockMomentum( learning_rate=learning_rate, momentum=0.2) - params_grads = append_backward_ops(mul_out) + mean_out = block.create_var( + dtype="float32", shape=[1], lod_level=0, name="mean.out") + block.append_op( + type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out}) + params_grads = append_backward_ops(mean_out) self.assertEqual(len(params_grads), 1) self.assertEqual(len(momentum_optimizer.get_accumulators()), 0) opts = momentum_optimizer.create_optimization_pass( @@ -132,10 +144,14 @@ class TestMomentumOptimizer(unittest.TestCase): "Y": mul_y}, outputs={"Out": mul_out}, attrs={"x_num_col_dims": 1}) + mean_out = block.create_var( + dtype="float32", shape=[1], lod_level=0, name="mean.out") + block.append_op( + type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out}) learning_rate = 0.01 momentum_optimizer = self.MockMomentum( learning_rate=learning_rate, momentum=0.2, use_nesterov=True) - params_grads = append_backward_ops(mul_out) + params_grads = append_backward_ops(mean_out) self.assertEqual(len(params_grads), 1) self.assertEqual(len(momentum_optimizer.get_accumulators()), 0) opts = momentum_optimizer.create_optimization_pass( @@ -186,10 +202,14 @@ class TestAdagradOptimizer(unittest.TestCase): "Y": mul_y}, outputs={"Out": mul_out}, attrs={"x_num_col_dims": 1}) + mean_out = block.create_var( + dtype="float32", shape=[1], lod_level=0, name="mean.out") + block.append_op( + type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out}) learning_rate = 0.01 adagrad_optimizer = self.MockAdagrad( learning_rate=learning_rate, epsilon=1.0e-6) - params_grads = append_backward_ops(mul_out) + params_grads = append_backward_ops(mean_out) self.assertEqual(len(params_grads), 1) self.assertEqual(len(adagrad_optimizer.get_accumulators()), 0) opts = adagrad_optimizer.create_optimization_pass(params_grads, mul_out, @@ -242,10 +262,14 @@ class TestAdamOptimizer(unittest.TestCase): "Y": mul_y}, outputs={"Out": mul_out}, attrs={"x_num_col_dims": 1}) + mean_out = block.create_var( + dtype="float32", shape=[1], lod_level=0, name="mean.out") + block.append_op( + type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out}) learning_rate = 0.01 adam_optimizer = self.MockAdam( learning_rate=learning_rate, beta1=0.9, beta2=0.999) - params_grads = append_backward_ops(mul_out) + params_grads = append_backward_ops(mean_out) self.assertEqual(len(params_grads), 1) self.assertEqual(len(adam_optimizer.get_accumulators()), 0) opts = adam_optimizer.create_optimization_pass(params_grads, mul_out, @@ -300,10 +324,14 @@ class TestAdamaxOptimizer(unittest.TestCase): "Y": mul_y}, outputs={"Out": mul_out}, attrs={"x_num_col_dims": 1}) + mean_out = block.create_var( + dtype="float32", shape=[1], lod_level=0, name="mean.out") + block.append_op( + type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out}) learning_rate = 0.01 adamax_optimizer = self.MockAdamax( learning_rate=learning_rate, beta1=0.9, beta2=0.999) - params_grads = append_backward_ops(mul_out) + params_grads = append_backward_ops(mean_out) self.assertEqual(len(params_grads), 1) self.assertEqual(len(adamax_optimizer.get_accumulators()), 0) opts = adamax_optimizer.create_optimization_pass(params_grads, mul_out, @@ -355,10 +383,14 @@ class TestDecayedAdagradOptimizer(unittest.TestCase): "Y": mul_y}, outputs={"Out": mul_out}, attrs={"x_num_col_dims": 1}) + mean_out = block.create_var( + dtype="float32", shape=[1], lod_level=0, name="mean.out") + block.append_op( + type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out}) learning_rate = 0.01 decayed_adagrad_optimizer = self.MockDecayedAdagrad( learning_rate=learning_rate, decay=0.95, epsilon=1.0e-6) - params_grads = append_backward_ops(mul_out) + params_grads = append_backward_ops(mean_out) self.assertEqual(len(params_grads), 1) self.assertEqual(len(decayed_adagrad_optimizer.get_accumulators()), 0) opts = decayed_adagrad_optimizer.create_optimization_pass( diff --git a/python/paddle/v2/fluid/tests/test_program.py b/python/paddle/v2/fluid/tests/test_program.py index ef2daf6916e14c015a39ae0193948e7ff6531449..e9bcefd21569aaa9225c676ea03b5c8e37d00333 100644 --- a/python/paddle/v2/fluid/tests/test_program.py +++ b/python/paddle/v2/fluid/tests/test_program.py @@ -1,6 +1,5 @@ import unittest -import paddle.v2.fluid.core as core from paddle.v2.fluid.framework import Program from paddle.v2.fluid.framework import g_main_program @@ -98,21 +97,26 @@ class TestProgram(unittest.TestCase): "Y": add_y}, outputs={"Out": add_out}, attrs={"x_num_col_dims": 1}) + mean_out = block.create_var( + dtype="float32", shape=[1], lod_level=0, name="mean.out") + block.append_op( + type="mean", inputs={"X": add_out}, outputs={"Out": mean_out}) self.assertEqual(mul_op.idx, 0) self.assertEqual(add_op.idx, 1) - param_to_grad = prog.append_backward(add_out, set()) + param_to_grad = prog.append_backward(mean_out, set()) def grad_name(name): return name + "@GRAD" - for var_name in ("mul.x", "mul.y", "mul.out", "add.y", "add.out"): + for var_name in ("mul.x", "mul.y", "mul.out", "add.y", "add.out", + "mean.out"): self.assertEqual(param_to_grad[var_name][0], grad_name(var_name)) self.assertEqual(param_to_grad[var_name][1], 0) expect_ops = [ - "mul", "elementwise_add", "fill_constant", "elementwise_add_grad", - "mul_grad" + "mul", "elementwise_add", "mean", "fill_constant", "mean_grad", + "elementwise_add_grad", "mul_grad" ] actual_ops = [] for op in block.ops: diff --git a/python/paddle/v2/fluid/tests/test_regularizer.py b/python/paddle/v2/fluid/tests/test_regularizer.py index f5d1eb3b96211bd7c7335dbe116a1d765d7bae50..24baf55e90c98f39bab926e8c85a791eee5ed4a4 100644 --- a/python/paddle/v2/fluid/tests/test_regularizer.py +++ b/python/paddle/v2/fluid/tests/test_regularizer.py @@ -29,7 +29,11 @@ class TestL2DecayRegularizer(unittest.TestCase): "Y": mul_y}, outputs={"Out": mul_out}, attrs={"x_num_col_dims": 1}) - params_grads = append_backward_ops(mul_out) + mean_out = block.create_var( + dtype="float32", shape=[1], lod_level=0, name="mean.out") + block.append_op( + type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out}) + params_grads = append_backward_ops(mean_out) self.assertEqual(len(params_grads), 1) count_ops = len(block.ops) params_grads = optimizer.append_regularization_ops(params_grads) @@ -62,7 +66,11 @@ class TestL1DecayRegularizer(unittest.TestCase): "Y": mul_y}, outputs={"Out": mul_out}, attrs={"x_num_col_dims": 1}) - params_grads = append_backward_ops(mul_out) + mean_out = block.create_var( + dtype="float32", shape=[1], lod_level=0, name="mean.out") + block.append_op( + type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out}) + params_grads = append_backward_ops(mean_out) self.assertEqual(len(params_grads), 1) count_ops = len(block.ops) params_grads = optimizer.append_regularization_ops(params_grads)