提交 d7e32ab3 编写于 作者: T tensor-tang

Merge remote-tracking branch 'upstream/develop' into merge

IfOp should have only one branch. An IfOp operator takes a `cond` variable whose value must be a vector of N boolean elements. Its return value has M (M<=N) instances, each corresponds to a true element in `cond`.
```python
import paddle as pd
x = var()
y = var()
cond = var()
b = pd.create_ifop(inputs=[x], output_num=1)
with b.true_block():
x = b.inputs(0)
z = operator.add(x, y)
b.set_output(0, operator.softmax(z))
out = b(cond)
```
If we want the output still has N instances, we can use IfElseOp with a default value, whose minibatch size must be N:
IfOp should have only one branch. An IfOp operator takes a `cond` variable whose value must be a vector of N boolean elements. Its return value has N instances. If cond[i] == True, input instance input[i] will go through true_block() and generate output[i]; otherwise it will produce output from false_bloack().
```python
import paddle as pd
......@@ -39,7 +21,7 @@ with b.false_block():
out = b(cond)
```
If only true_block is set in an IfElseOp, we can have a default value for false as:
If only true_block is set in an IfElseOp, a special case is that we can have a default value for false as:
```python
import paddle as pd
......
......@@ -22,7 +22,7 @@ namespace framework {
template <typename T>
inline void Tensor::check_memory_size() const {
PADDLE_ENFORCE_NOT_NULL(
holder_, "Tenosr holds no memory. Call Tensor::mutable_data first.");
holder_, "Tensor holds no memory. Call Tensor::mutable_data first.");
PADDLE_ENFORCE_GE(
holder_->size(), numel() * sizeof(T) + offset_,
"Tensor's dims_ is out of bound. Call Tensor::mutable_data "
......
......@@ -36,7 +36,7 @@ TEST(Tensor, DataAssert) {
} catch (paddle::platform::EnforceNotMet err) {
caught = true;
std::string msg =
"holder_ should not be null\nTenosr holds no memory. Call "
"holder_ should not be null\nTensor holds no memory. Call "
"Tensor::mutable_data first.";
const char* what = err.what();
for (size_t i = 0; i < msg.length(); ++i) {
......@@ -112,7 +112,7 @@ TEST(Tensor, ShareDataWith) {
} catch (paddle::platform::EnforceNotMet err) {
caught = true;
std::string msg =
"holder_ should not be null\nTenosr holds no memory. Call "
"holder_ should not be null\nTensor holds no memory. Call "
"Tensor::mutable_data first.";
const char* what = err.what();
for (size_t i = 0; i < msg.length(); ++i) {
......@@ -274,4 +274,4 @@ TEST(Tensor, ReshapeToMatrix) {
Tensor res = ReshapeToMatrix<int>(src, 2);
ASSERT_EQ(res.dims()[0], 2 * 3);
ASSERT_EQ(res.dims()[1], 4 * 9);
}
\ No newline at end of file
}
......@@ -62,6 +62,24 @@ void Copy<platform::GPUPlace, platform::GPUPlace>(platform::GPUPlace dst_place,
}
}
template <>
void Copy<platform::CPUPlace, platform::GPUPlace>(platform::CPUPlace dst_place,
void* dst,
platform::GPUPlace src_place,
const void* src, size_t num) {
platform::SetDeviceId(src_place.device);
platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost);
}
template <>
void Copy<platform::GPUPlace, platform::CPUPlace>(platform::GPUPlace dst_place,
void* dst,
platform::CPUPlace src_place,
const void* src, size_t num) {
platform::SetDeviceId(dst_place.device);
platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
}
#endif // PADDLE_ONLY_CPU
} // namespace memory
......
......@@ -80,9 +80,11 @@ endfunction()
add_subdirectory(math)
set(DEPS_OPS
recurrent_op)
recurrent_op
cond_op)
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS framework_proto tensor net_op)
op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach(src ${GENERAL_OPS})
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/cond_op.h"
#include <cstring>
#include <sstream>
#include "paddle/framework/op_registry.h"
#include "paddle/operators/gather.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/scatter.h"
namespace paddle {
namespace operators {
using Scope = framework::Scope;
using Variable = framework::Variable;
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DDim = framework::DDim;
void CondOp::CreateScope(const Scope& scope) const {
auto sub_scopes_var = scope.FindVar("SubScopes");
PADDLE_ENFORCE(sub_scopes_var != nullptr, "");
auto sub_scopes = sub_scopes_var->GetMutable<std::vector<Scope*>>();
auto& sub_scope = scope.NewScope();
sub_scopes->push_back(&sub_scope);
}
void CondOp::CreateIndexTensor(const Scope& scope) const {
auto index_tensors_var = scope.FindVar("IndexTensors");
PADDLE_ENFORCE(index_tensors_var != nullptr, "");
auto& index_tensors =
*index_tensors_var->GetMutable<std::vector<LoDTensor>>();
index_tensors.push_back(LoDTensor());
}
void CondOp::InferShape(const Scope& scope) const {
auto sub_scopes_var = scope.FindVar("SubScopes");
PADDLE_ENFORCE_NOT_NULL(sub_scopes_var);
auto& sub_scopes = *sub_scopes_var->GetMutable<std::vector<Scope*>>();
for (int i = 0; i < 2; ++i) {
// Create two sub scopes for true and false branches
// sub_scopes[0] for the true branch and sub_scopes[1] for the false
// branch
CreateScope(scope);
// Create two tensors for true and false indices
// index_tensors[0] for the true branch and index_tensors[1] for the false
// branch
CreateIndexTensor(scope);
PADDLE_ENFORCE(!Inputs("Xs").empty(), "Inputs can't be empty");
for (auto& input : Inputs("Xs")) {
// Create a new tensor in sub-scope for input-type tensor
Variable* v = sub_scopes[i]->NewVar(input);
LoDTensor* sub_input = v->GetMutable<LoDTensor>();
sub_input->Resize(scope.FindVar(input)->GetMutable<LoDTensor>()->dims());
}
for (auto& output : (*sub_net_op_[i]).Outputs()) {
for (auto& var_name : output.second) {
sub_scopes[i]->NewVar(var_name);
}
}
// each net calls InferShape
sub_net_op_[i]->InferShape(*sub_scopes[i]);
}
for (auto& output : Outputs("Outs")) {
LoDTensor* tensor_t_out =
sub_scopes[0]->FindVar(output)->GetMutable<LoDTensor>();
PADDLE_ENFORCE_NOT_NULL(tensor_t_out, "True output should not be NULL");
LoDTensor* tensor_f_out =
sub_scopes[1]->FindVar(output)->GetMutable<LoDTensor>();
PADDLE_ENFORCE_NOT_NULL(tensor_f_out, "False output should not be NULL");
auto* tensor_out_var = scope.FindVar(output);
PADDLE_ENFORCE_NOT_NULL(tensor_out_var, "Output not found");
LoDTensor* tensor_out = tensor_out_var->GetMutable<LoDTensor>();
PADDLE_ENFORCE_NOT_NULL(tensor_t_out,
"True output tensor should not be NULL");
// check output size should be same
PADDLE_ENFORCE_EQ(tensor_t_out->dims(), tensor_f_out->dims(),
"Outputs not of the same shape");
tensor_out->Resize(tensor_t_out->dims());
// tensor_out->mutable_data<float>(tensor_out->dims(),
// platform::CPUPlace());
tensor_out->mutable_data<float>(platform::CPUPlace());
}
}
void CondOp::Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const {
auto* sub_scopes_var = scope.FindVar("SubScopes");
auto sub_scopes = sub_scopes_var->Get<std::vector<Scope*>>();
auto* index_tensors_var = scope.FindVar("IndexTensors");
auto index_tensors = index_tensors_var->Get<std::vector<LoDTensor>>();
std::string cond_name = Input("Cond");
Variable* cond_var = scope.FindVar(cond_name);
PADDLE_ENFORCE_NOT_NULL(cond_var);
const LoDTensor* cond = cond_var->GetMutable<LoDTensor>();
// Step 1: get the true/false index at runtime
// index_[0]: vector<int>, contains all index for cond[i] == true
// index_[1]: vector<int>, contains all index for cond[i] == false
for (int i = 0; i < 2; ++i) index_[i].clear();
const int* cond_data = cond->data<int>();
for (int i = 0; i < cond->dims()[0]; ++i) {
if (cond_data[i])
index_[0].push_back(i);
else
index_[1].push_back(i);
}
// put index_[0] and index_[1] into two tensors:
// index_tensor_[0] and index_tensor_[1]
DDim dim = paddle::framework::make_ddim({0});
for (int i = 0; i < 2; ++i) {
dim[0] = index_[i].size();
int* tmp_ptr =
index_tensors[i].mutable_data<int>(dim, platform::CPUPlace());
index_tensors[i].Resize(dim);
memcpy(tmp_ptr, index_[i].data(), dim[0] * sizeof(int));
}
// Step 2: collect data by calling gather
for (int i = 0; i < 2; ++i) {
// i= 0/i for True and False branches respectively
for (auto& input : Inputs("Xs")) {
// find Tensor
Variable* v = scope.FindVar(input);
PADDLE_ENFORCE_NOT_NULL(v);
LoDTensor* tensor_parent = v->GetMutable<LoDTensor>();
v = sub_scopes[i]->FindVar(input);
PADDLE_ENFORCE_NOT_NULL(v);
LoDTensor* tensor_child = v->GetMutable<LoDTensor>();
// Resize child
DDim dim = tensor_child->dims();
dim[0] = index_[i].size();
tensor_child->Resize(dim);
tensor_child->mutable_data<float>(dim, platform::CPUPlace());
Gather<float>(dev_ctx.GetPlace(), tensor_parent, &index_tensors[i],
tensor_child);
}
}
// Step 3: run
for (int i = 0; i < 2; ++i) {
sub_net_op_[i]->Run(*sub_scopes[i], dev_ctx);
}
// Step 4: merge output results
for (int i = 0; i < 2; ++i) {
// i= 0/i for True and False branches respectively
for (auto& output : Outputs("Outs")) {
// find Tensor
Variable* v = scope.FindVar(output);
PADDLE_ENFORCE_NOT_NULL(v);
LoDTensor* tensor_parent = v->GetMutable<LoDTensor>();
v = sub_scopes[i]->FindVar(output);
PADDLE_ENFORCE_NOT_NULL(v);
LoDTensor* tensor_child = v->GetMutable<LoDTensor>();
ScatterUpdate<float>(dev_ctx.GetPlace(), tensor_child, &index_tensors[i],
tensor_parent);
}
}
}
class CondOpProtoAndCheckerMaker : public framework::OpProtoAndCheckerMaker {
public:
CondOpProtoAndCheckerMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Cond", "The condition, which is a bool vector");
AddInput("Xs", "Inputs of Subnets").AsDuplicable();
AddOutput("Outs", "Outputs of Cond_Op after merge").AsDuplicable();
AddOutput("SubScopes", "sub scopes for true and false branches");
AddOutput("IndexTensors", "Index Tensors contains indices for true/false");
AddComment(R"DOC(
Sample dependent Cond Operator:
Given Cond[i] as a 1/0 vector to indicate true/false
The equation is:
Out[i] = subnet_t[i], if Cond[i] == true
Out[i] = subnet_t[i], if Cond[i] == false
)DOC");
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_WITHOUT_GRADIENT(cond, paddle::operators::CondOp,
paddle::operators::CondOpProtoAndCheckerMaker);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "glog/logging.h"
#include "paddle/framework/ddim.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/tensor.h"
#include "paddle/operators/net_op.h"
namespace paddle {
namespace operators {
/*
* @brief CondOp is a dynamic if-else Operator
*
* It has a input tensor named cond indicating which netop each instance will
* run.
*
* if cond == 1, it will run true_net, which is a NetOp.
*
* if cond == 0, it will run false_net, which is another NetOp.
*/
class CondOp : public framework::OperatorBase {
public:
CondOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {
index_.resize(2);
sub_net_op_.resize(2);
}
CondOp(const CondOp& o)
: framework::OperatorBase(
static_cast<const framework::OperatorBase&>(o)) {
// TODO(yuyang18): Implement copy ctor well.
PADDLE_THROW("Not implemented");
}
void CreateScope(const framework::Scope& scope) const;
void CreateIndexTensor(const framework::Scope& scope) const;
/*
* InferShape must be called before Run.
*/
void InferShape(const framework::Scope& scope) const override;
/*
* Set True Block
*/
void set_truenet(std::unique_ptr<OperatorBase>&& net) {
sub_net_op_[0] = std::move(net);
}
/*
* Set False Block
*/
void set_falsenet(std::unique_ptr<OperatorBase>&& net) {
sub_net_op_[1] = std::move(net);
}
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const override;
private:
// sub_net_op_[0]: subnet_t
// sub_net_op_[1]: subnet_f
std::vector<std::unique_ptr<framework::OperatorBase>> sub_net_op_;
// index_[0]: True_index;
// index_[1]: False_index;
mutable std::vector<std::vector<int>> index_;
};
} // namespace operators
} // namespace paddle
......@@ -56,7 +56,7 @@ class CosSimKernel : public framework::OpKernel {
x_norm.device(place) = x.square().sum(row_along).sqrt();
y_norm.device(place) = y.square().sum(row_along).sqrt();
if (rows_x == rows_y) {
auto xy = (x * y).sum(Eigen::array<int, 1>({1}));
auto xy = (x * y).sum(Eigen::array<int, 1>({{1}}));
z.device(place) = xy / x_norm / y_norm;
} else {
Eigen::DSizes<int, 2> bcast(rows_x, 1);
......@@ -134,7 +134,7 @@ class CosSimGradKernel : public framework::OpKernel {
out_grad_y->mutable_data<T>(context.GetPlace());
auto dy = EigenMatrix<T>::Reshape(*out_grad_y, 1);
auto grad = x / norm_prod_bcast - z_bcast * y_bcast / y_snorm_bcast;
dy.device(place) = (dz_bcast * grad).sum(Eigen::array<int, 1>({0}));
dy.device(place) = (dz_bcast * grad).sum(Eigen::array<int, 1>({{0}}));
}
}
}
......
......@@ -13,10 +13,8 @@
limitations under the License. */
#pragma once
#include <iostream>
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
......
......@@ -77,8 +77,6 @@ class MinusGradOp : public NetOp {
} // namespace operators
} // namespace paddle
USE_OP(scale);
USE_NO_KERNEL_OP(identity);
namespace ops = paddle::operators;
REGISTER_OP(minus, ops::MinusOp, ops::MinusOpMaker, minus_grad,
ops::MinusGradOp<float>);
......
......@@ -24,3 +24,4 @@ cc_library(device_context SRCS device_context.cc DEPS memory buddy_allocator
nv_test(device_context_test SRCS device_context_test.cc DEPS device_context gpu_info)
nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda)
nv_test(transform_test SRCS transform_test.cu DEPS paddle_memory place)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifndef __NVCC__
#error device_ptr_cast must be include by .cu file
#endif
#include <thrust/device_ptr.h>
namespace paddle {
namespace platform {
namespace details {
template <typename T, bool is_ptr>
struct DevicePtrCast;
template <typename T>
struct DevicePtrCast<T, true> {
using ELEM = typename std::remove_pointer<T>::type;
using RTYPE = thrust::device_ptr<ELEM>;
inline thrust::device_ptr<ELEM> operator()(ELEM* ele) const {
return thrust::device_pointer_cast(ele);
}
};
template <typename T>
struct DevicePtrCast<T, false> {
using RTYPE = T;
inline RTYPE operator()(RTYPE it) const { return it; }
};
// Cast T to thrust::device_ptr if T is a pointer.
// Otherwise, e.g., T is a iterator, return T itself.
template <typename T>
auto DevPtrCast(T t) ->
typename DevicePtrCast<T, std::is_pointer<T>::value>::RTYPE {
DevicePtrCast<T, std::is_pointer<T>::value> cast;
return cast(t);
}
} // namespace details
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/platform/enforce.h"
#include "paddle/platform/hostdevice.h"
#include "paddle/platform/place.h"
#include <algorithm>
#include <type_traits>
#ifdef __NVCC__
#include <thrust/transform.h>
#include "paddle/platform/details/device_ptr_cast.h"
#endif
namespace paddle {
namespace platform {
// Transform on host or device. It provides the same API in std library.
template <typename Place, typename InputIter, typename OutputIter,
typename UnaryOperation>
void Transform(Place place, InputIter first, InputIter last, OutputIter result,
UnaryOperation op) {
if (is_cpu_place(place)) {
std::transform(first, last, result, op);
} else {
#ifdef __NVCC__
using namespace details;
thrust::transform(DevPtrCast(first), DevPtrCast(last), DevPtrCast(result),
op);
#else
PADDLE_THROW("Do not invoke `Transform<GPUPlace>` in .cc file");
#endif
}
}
template <typename Place, typename InputIter1, typename InputIter2,
typename OutputIter, typename BinaryOperation>
void Transform(Place place, InputIter1 first1, InputIter1 last1,
InputIter2 first2, OutputIter result, BinaryOperation op) {
if (is_cpu_place(place)) {
std::transform(first1, last1, first2, result, op);
} else {
#ifdef __NVCC__
using namespace details;
thrust::transform(DevPtrCast(first1), DevPtrCast(last1), DevPtrCast(first2),
DevPtrCast(result), op);
#else
PADDLE_THROW("Do not invoke `Transform<GPUPlace>` in .cc file");
#endif
}
};
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/memory/memcpy.h"
#include "paddle/memory/memory.h"
#include "paddle/platform/transform.h"
template <typename T>
class Scale {
public:
explicit Scale(const T& scale) : scale_(scale) {}
HOSTDEVICE T operator()(const T& a) const { return a * scale_; }
private:
T scale_;
};
template <typename T>
class Multiply {
public:
HOSTDEVICE T operator()(const T& a, const T& b) const { return a * b; }
};
TEST(Transform, CPUUnary) {
using namespace paddle::platform;
float buf[4] = {0.1, 0.2, 0.3, 0.4};
Transform(CPUPlace(), buf, buf + 4, buf, Scale<float>(10));
for (int i = 0; i < 4; ++i) {
ASSERT_NEAR(buf[i], static_cast<float>(i + 1), 1e-5);
}
}
TEST(Transform, GPUUnary) {
using namespace paddle::platform;
using namespace paddle::memory;
GPUPlace gpu0(0);
float cpu_buf[4] = {0.1, 0.2, 0.3, 0.4};
float* gpu_buf = static_cast<float*>(Alloc(gpu0, sizeof(float) * 4));
Copy(gpu0, gpu_buf, CPUPlace(), cpu_buf, sizeof(cpu_buf));
Transform(gpu0, gpu_buf, gpu_buf + 4, gpu_buf, Scale<float>(10));
Copy(CPUPlace(), cpu_buf, gpu0, gpu_buf, sizeof(cpu_buf));
Free(gpu0, gpu_buf);
for (int i = 0; i < 4; ++i) {
ASSERT_NEAR(cpu_buf[i], static_cast<float>(i + 1), 1e-5);
}
}
TEST(Transform, CPUBinary) {
using namespace paddle::platform;
using namespace paddle::memory;
int buf[4] = {1, 2, 3, 4};
Transform(CPUPlace(), buf, buf + 4, buf, buf, Multiply<int>());
for (int i = 0; i < 4; ++i) {
ASSERT_EQ((i + 1) * (i + 1), buf[i]);
}
}
TEST(Transform, GPUBinary) {
using namespace paddle::platform;
using namespace paddle::memory;
int buf[4] = {1, 2, 3, 4};
GPUPlace gpu0(0);
int* gpu_buf = static_cast<int*>(Alloc(gpu0, sizeof(buf)));
Copy(gpu0, gpu_buf, CPUPlace(), buf, sizeof(buf));
Transform(gpu0, gpu_buf, gpu_buf + 4, gpu_buf, gpu_buf, Multiply<int>());
Copy(CPUPlace(), buf, gpu0, gpu_buf, sizeof(buf));
Free(gpu0, gpu_buf);
for (int i = 0; i < 4; ++i) {
ASSERT_EQ((i + 1) * (i + 1), buf[i]);
}
}
\ No newline at end of file
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/framework/backward.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/cond_op.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/recurrent_op.h"
#include "paddle/platform/enforce.h"
......@@ -288,6 +289,28 @@ All parameter, weight, gradient are variables in Paddle.
[](operators::RecurrentOp &self, const operators::NetOp &net)
-> void { self.set_stepnet(net.Clone()); });
// cond_op
py::class_<operators::CondOp, OperatorBase>(m, "CondOp")
.def_static("create",
[](py::bytes protobin) -> operators::CondOp * {
OpDesc desc;
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
"Cannot parse user input to OpDesc");
PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
auto cond_op = OpRegistry::CreateOp(desc);
return static_cast<operators::CondOp *>(cond_op.release());
})
.def("set_truenet",
[](operators::CondOp &self, const operators::NetOp &net) -> void {
self.set_truenet(net.Clone());
})
.def("set_falsenet",
[](operators::CondOp &self, const operators::NetOp &net) -> void {
self.set_falsenet(net.Clone());
});
m.def("unique_integer", UniqueIntegerGenerator);
m.def("is_compile_gpu", IsCompileGPU);
......
......@@ -215,5 +215,27 @@ class __RecurrentOp__(object):
return core.RecurrentOp.create(proto.SerializeToString())
class __CondOp__(object):
__proto__ = None
type = "cond"
def __init__(self):
# cache recurrent_op's proto
if self.__proto__ is None:
for op_proto in get_all_op_protos():
if op_proto.type == self.type:
self.__proto__ = op_proto
def __call__(self, *args, **kwargs):
if self.type not in args and "type" not in kwargs:
kwargs["type"] = self.type
# create proto
create_method = OpDescCreationMethod(self.__proto__)
proto = create_method(*args, **kwargs)
# create condop
return core.CondOp.create(proto.SerializeToString())
Operator = OperatorFactory() # The default global factory
RecurrentOp = __RecurrentOp__()
CondOp = __CondOp__()
import logging
import paddle.v2.framework.core as core
import unittest
import numpy as np
from paddle.v2.framework.op import Operator, CondOp
class PySimpleCond(object):
'''
A simple implementation of dynamic if-else based on numpy
'''
def __init__(self):
array = [1] * 10
for i in range(1, 10, 2):
array[i] = 0
self.cond = np.array(array)
self.x = np.ones(shape=(10, 1))
def forward(self):
self.index_t = np.where(self.cond == 1)
self.index_f = np.where(self.cond == 0)
y_t = self.x[self.index_t]
y_f = self.x[self.index_f]
y_t = y_t * 2.
y_f = y_f * (-2.)
output = np.zeros(shape=(10, 1))
output[self.index_t] = y_t
output[self.index_f] = y_f
return output
class PySimpleCondTest(unittest.TestCase):
def setUp(self):
self.condnn = PySimpleCond()
def test_forward(self):
output = self.condnn.forward()
def create_tensor(scope, name, shape, np_data):
tensor = scope.new_var(name).get_tensor()
tensor.set_dims(shape)
tensor.set(np_data, core.CPUPlace())
return tensor
class TestCondOp(unittest.TestCase):
'''
Test CondOp
equation:
cond = [True, False, True, False, ...]
y[index_t] = x[index_t] * 2.
y[index_f] = x[index_f] * -2.
outputs:
y
'''
def setUp(self):
self.py_cond = PySimpleCond()
def forward(self):
self.scope = core.Scope()
self.create_global_variables()
self.create_cond_op()
self.create_sub_net()
ctx = core.DeviceContext.create(core.CPUPlace())
self.condop.infer_shape(self.scope)
self.condop.run(self.scope, ctx)
return np.array(self.scope.find_var("Out").get_tensor())
def create_global_variables(self):
x_np_data = self.py_cond.x
create_tensor(self.scope, "X", [10, 1], x_np_data)
cond_np_data = self.py_cond.cond.astype("int32")
create_tensor(self.scope, "cond", [10, 1], cond_np_data)
self.scope.new_var("SubScopes")
self.scope.new_var("IndexTensors")
self.scope.new_var("Out")
def create_cond_op(self):
self.condop = CondOp(
Cond="cond",
Xs=["X"],
Outs=["Out"],
SubScopes="SubScopes",
IndexTensors="IndexTensors")
def create_sub_net(self):
truenet = core.Net.create()
scale_op_t = Operator("scale", X='X', Out='Out', scale=2.)
truenet.append_op(scale_op_t)
truenet.complete_add_op(True)
self.condop.set_truenet(truenet)
falsenet = core.Net.create()
scale_op_t = Operator("scale", X='X', Out='Out', scale=-2.)
falsenet.append_op(scale_op_t)
falsenet.complete_add_op(True)
self.condop.set_falsenet(falsenet)
def test_forward(self):
print 'test cond op forward'
pd_output = self.forward()
py_output = self.py_cond.forward()
print 'pd_output', pd_output
print
print 'py_output', py_output
self.assertEqual(pd_output.shape, py_output.shape)
print 'test passed'
return 0
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册