Commit c7db6e8d authored by zchen0211

cond op passed

Parent b8e75c1f
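This commit moves the CondOp implementation from cond_op.h into cond_op.cc, renames the registered type from cond_op to cond, switches the condition tensor to int32 data, and adds test_cond_op.py. CondOp runs one of two subnets per sample: rows of the inputs where Cond[i] is nonzero are gathered into a true-branch sub-scope and run through subnet_t, the remaining rows through subnet_f, and the branch outputs are scattered back into one output tensor. Below is a minimal standalone sketch of that per-row selection (an editor's illustration, not the PaddlePaddle API; the 2/-2 scale factors mirror the unit test at the end of this diff):

#include <cstdio>
#include <vector>

int main() {
  // same data as the test: cond alternates 1/0, x is all ones
  std::vector<int> cond = {1, 0, 1, 0, 1, 0, 1, 0, 1, 0};
  std::vector<float> x(10, 1.0f), out(10);
  std::vector<int> index[2];  // index[0]: rows with cond != 0; index[1]: the rest
  for (int i = 0; i < static_cast<int>(cond.size()); ++i)
    index[cond[i] ? 0 : 1].push_back(i);
  // "subnet_t" scales by 2, "subnet_f" scales by -2; results land back
  // at their original row positions (the gather -> run -> scatter pattern)
  for (int i : index[0]) out[i] = 2.0f * x[i];
  for (int i : index[1]) out[i] = -2.0f * x[i];
  for (float v : out) std::printf("%g ", v);  // prints: 2 -2 2 -2 ...
  std::printf("\n");
  return 0;
}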
paddle/operators/cond_op.cc
@@ -13,15 +13,175 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/cond_op.h"
#include <cstring>
#include <sstream>
#include "paddle/framework/op_registry.h"
#include "paddle/operators/gather.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/scatter.h"
namespace paddle {
namespace operators {
using Scope = framework::Scope;
using Variable = framework::Variable;
using Tensor = framework::Tensor;
using DDim = framework::DDim;
void CondOp::CreateScope(const Scope& scope) const {
auto sub_scopes_var = scope.FindVar("SubScopes");
PADDLE_ENFORCE(sub_scopes_var != nullptr, "SubScopes variable should not be null");
auto sub_scopes = sub_scopes_var->GetMutable<std::vector<Scope*>>();
auto& sub_scope = scope.NewScope();
sub_scopes->push_back(&sub_scope);
}
void CondOp::CreateIndexTensor(const Scope& scope) const {
auto index_tensors_var = scope.FindVar("IndexTensors");
PADDLE_ENFORCE(index_tensors_var != nullptr, "IndexTensors variable should not be null");
auto& index_tensors = *index_tensors_var->GetMutable<std::vector<Tensor>>();
index_tensors.push_back(Tensor());
}
void CondOp::InferShape(const Scope& scope) const {
auto sub_scopes_var = scope.FindVar("SubScopes");
PADDLE_ENFORCE_NOT_NULL(sub_scopes_var);
auto& sub_scopes = *sub_scopes_var->GetMutable<std::vector<Scope*>>();
for (int i = 0; i < 2; ++i) {
// Create two sub scopes for true and false branches
// sub_scopes[0] for the true branch and sub_scopes[1] for the false
// branch
CreateScope(scope);
// Create two tensors for true and false indices
// index_tensors[0] for the true branch and index_tensors[1] for the false
// branch
CreateIndexTensor(scope);
PADDLE_ENFORCE(!Inputs("Xs").empty(), "Inputs can't be empty");
for (auto& input : Inputs("Xs")) {
// Create a new tensor in sub-scope for input-type tensor
Variable* v = sub_scopes[i]->NewVar(input);
Tensor* sub_input = v->GetMutable<Tensor>();
sub_input->Resize(scope.FindVar(input)->GetMutable<Tensor>()->dims());
}
for (auto& output : (*sub_net_op_[i]).Outputs()) {
for (auto& var_name : output.second) {
sub_scopes[i]->NewVar(var_name);
}
}
// each net calls InferShape
sub_net_op_[i]->InferShape(*sub_scopes[i]);
}
for (auto& output : Outputs("Outs")) {
Tensor* tensor_t_out = sub_scopes[0]->FindVar(output)->GetMutable<Tensor>();
PADDLE_ENFORCE_NOT_NULL(tensor_t_out, "True branch output should not be null");
Tensor* tensor_f_out = sub_scopes[1]->FindVar(output)->GetMutable<Tensor>();
PADDLE_ENFORCE_NOT_NULL(tensor_f_out, "False branch output should not be null");
auto* tensor_out_var = scope.FindVar(output);
PADDLE_ENFORCE_NOT_NULL(tensor_out_var, "Output not found");
Tensor* tensor_out = tensor_out_var->GetMutable<Tensor>();
PADDLE_ENFORCE_NOT_NULL(tensor_out, "Output tensor should not be null");
// the true and false branch outputs must have the same shape
PADDLE_ENFORCE_EQ(tensor_t_out->dims(), tensor_f_out->dims(),
"Outputs not of the same shape");
tensor_out->Resize(tensor_t_out->dims());
tensor_out->mutable_data<float>(tensor_out->dims(), platform::CPUPlace());
}
}
void CondOp::Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const {
auto sub_scopes = scope.FindVar("SubScopes")->Get<std::vector<Scope*>>();
auto index_tensors =
scope.FindVar("IndexTensors")->Get<std::vector<Tensor>>();
std::string cond_name = Input("Cond");
Variable* cond_var = scope.FindVar(cond_name);
PADDLE_ENFORCE_NOT_NULL(cond_var);
const Tensor* cond = cond_var->GetMutable<Tensor>();
// Step 1: get the true/false index at runtime
// index_[0]: vector<int>, holds all indices i with cond[i] != 0
// index_[1]: vector<int>, holds all indices i with cond[i] == 0
for (int i = 0; i < 2; ++i) index_[i].clear();
const int* cond_data = cond->data<int>();
for (int i = 0; i < cond->dims()[0]; ++i) {
if (cond_data[i])
index_[0].push_back(i);
else
index_[1].push_back(i);
}
// copy index_[0] and index_[1] into the two index tensors:
// index_tensors[0] and index_tensors[1]
DDim dim = paddle::framework::make_ddim({0});
for (int i = 0; i < 2; ++i) {
dim[0] = index_[i].size();
int* tmp_ptr =
index_tensors[i].mutable_data<int>(dim, platform::CPUPlace());
memcpy(tmp_ptr, index_[i].data(), dim[0] * sizeof(int));
}
// Step 2: collect data by calling gather
for (int i = 0; i < 2; ++i) {
// i = 0 for the true branch, i = 1 for the false branch
for (auto& input : Inputs("Xs")) {
// find Tensor
Variable* v = scope.FindVar(input);
PADDLE_ENFORCE_NOT_NULL(v);
Tensor* tensor_parent = v->GetMutable<Tensor>();
v = sub_scopes[i]->FindVar(input);
PADDLE_ENFORCE_NOT_NULL(v);
Tensor* tensor_child = v->GetMutable<Tensor>();
// Resize child
DDim dim = tensor_child->dims();
dim[0] = index_[i].size();
tensor_child->Resize(dim);
tensor_child->mutable_data<float>(dim, platform::CPUPlace());
Gather<float>(dev_ctx.GetPlace(), tensor_parent, &index_tensors[i],
tensor_child);
}
}
// Step 3: run
for (int i = 0; i < 2; ++i) sub_net_op_[i]->Run(*sub_scopes[i], dev_ctx);
// Step 4: merge output results
for (int i = 0; i < 2; ++i) {
// i = 0 for the true branch, i = 1 for the false branch
for (auto& output : Outputs("Outs")) {
// find Tensor
Variable* v = scope.FindVar(output);
PADDLE_ENFORCE_NOT_NULL(v);
Tensor* tensor_parent = v->GetMutable<Tensor>();
v = sub_scopes[i]->FindVar(output);
PADDLE_ENFORCE_NOT_NULL(v);
Tensor* tensor_child = v->GetMutable<Tensor>();
ScatterUpdate<float>(dev_ctx.GetPlace(), tensor_child, &index_tensors[i],
tensor_parent);
}
}
}
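// Editor's note: Run() above implements the whole forward pass in four steps:
//   1. split row indices by the condition; e.g. cond = {1, 0, 1} yields
//      index_[0] = {0, 2} (true rows) and index_[1] = {1} (false rows);
//   2. Gather copies the selected rows of every input in "Xs" into the
//      branch's sub-scope, so each subnet sees a dense tensor of its rows;
//   3. each subnet runs inside its own sub-scope;
//   4. ScatterUpdate writes each branch's outputs back to the original row
//      positions of the merged "Outs" tensors.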
class CondOpProtoAndCheckerMaker : public framework::OpProtoAndCheckerMaker {
public:
CondOpProtoAndCheckerMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Cond", "The condition, which is a bool vector");
AddInput("Xs", "Inputs of Subnets").AsDuplicable();
@@ -41,5 +201,5 @@ Out[i] = subnet_f[i], if Cond[i] == false
} // namespace operators
} // namespace paddle
REGISTER_OP_WITHOUT_GRADIENT(cond, paddle::operators::CondOp,
paddle::operators::CondOpProtoAndCheckerMaker);
paddle/operators/cond_op.h
@@ -19,22 +19,19 @@ limitations under the License. */
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/tensor.h"
#include "paddle/operators/gather.h"
#include "paddle/operators/scatter.h"
#include "paddle/operators/net_op.h"
namespace paddle {
namespace operators {
class CondOp : public framework::OperatorBase {
public:
CondOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {
index_.resize(2);
sub_net_op_.resize(2);
}
CondOp(const CondOp& o)
@@ -44,87 +41,14 @@ class CondOp : public OperatorBase {
PADDLE_THROW("Not implemented");
}
void CreateScope(const framework::Scope& scope) const;
void CreateIndexTensor(const framework::Scope& scope) const;
/**
* InferShape must be called before Run.
*/
void InferShape(const framework::Scope& scope) const override;
// Set True Block
void set_truenet(std::unique_ptr<OperatorBase> net) {
@@ -137,74 +61,7 @@ class CondOp : public OperatorBase {
}
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const override;
private:
// sub_net_op_[0]: subnet_t
@@ -216,17 +73,5 @@ class CondOp : public OperatorBase {
mutable std::vector<std::vector<int>> index_;
};
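// Editor's sketch of the expected call sequence, mirroring test_cond_op.py
// below (variable names are illustrative):
//
//   CondOp op("cond", inputs, outputs, attrs);
//   op.set_truenet(std::move(subnet_t));   // runs on rows with Cond[i] != 0
//   op.set_falsenet(std::move(subnet_f));  // runs on rows with Cond[i] == 0
//   op.InferShape(scope);  // must be called before Run(), see comment above
//   op.Run(scope, dev_ctx);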
} // namespace operators
} // namespace paddle
@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/framework/backward.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/cond_op.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/recurrent_op.h"
#include "paddle/platform/enforce.h"
@@ -217,7 +217,7 @@ class __RecurrentOp__(object):
class __CondOp__(object):
__proto__ = None
type = "cond"
def __init__(self):
# cache cond_op's proto
@@ -227,8 +227,8 @@ class __CondOp__(object):
self.__proto__ = op_proto
def __call__(self, *args, **kwargs):
if self.type not in args and "type" not in kwargs:
kwargs["type"] = self.type
# create proto
create_method = OpDescCreationMethod(self.__proto__)
proto = create_method(*args, **kwargs)
python/paddle/v2/framework/tests/CMakeLists.txt
@@ -27,6 +27,7 @@ py_test(test_operator SRCS test_operator.py)
py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py)
py_test(test_uniform_random_op SRCS test_uniform_random_op.py)
py_test(test_recurrent_op SRCS test_recurrent_op.py)
py_test(test_cond_op SRCS test_cond_op.py)
py_test(test_sgd_op SRCS test_sgd_op.py)
py_test(test_gradient_checker SRCS test_gradient_checker.py)
py_test(test_lookup_table SRCS test_lookup_table.py)
python/paddle/v2/framework/tests/test_cond_op.py
@@ -11,15 +11,15 @@ class PySimpleCond(object):
'''
def __init__(self):
array = [1] * 10
for i in range(1, 10, 2):
array[i] = 0
self.cond = np.array(array)
self.x = np.ones(shape=(10, 1))
def forward(self):
self.index_t = np.where(self.cond == 1)
self.index_f = np.where(self.cond == 0)
y_t = self.x[self.index_t]
y_f = self.x[self.index_f]
y_t = y_t * 2.
@@ -36,7 +36,6 @@ class PySimpleCondTest(unittest.TestCase):
def test_forward(self):
output = self.condnn.forward()
def create_tensor(scope, name, shape, np_data):
@@ -67,47 +66,50 @@ class TestCondOp(unittest.TestCase):
self.create_cond_op()
self.create_sub_net()
ctx = core.DeviceContext.create(core.CPUPlace())
self.condop.infer_shape(self.scope)
self.condop.run(self.scope, ctx)
return np.array(self.scope.find_var("Outs").get_tensor())
return np.array(self.scope.find_var("Out").get_tensor())
def create_global_variables(self):
x_np_data = self.py_cond.x
create_tensor(self.scope, "x", [10, 1], x_np_data)
cond_np_data = self.py_cond.cond
create_tensor(self.scope, "cond", [10, 1], x_np_data)
create_tensor(self.scope, "X", [10, 1], x_np_data)
cond_np_data = self.py_cond.cond.astype("int32")
create_tensor(self.scope, "cond", [10, 1], cond_np_data)
self.scope.new_var("SubScopes")
self.scope.new_var("IndexTensors")
self.scope.new_var("Outs")
self.scope.new_var("Out")
def create_cond_op(self):
self.condop = CondOp(
Cond="cond",
Xs=["x"],
Outs=['Out_final'],
Xs=["X"],
Outs=["Out"],
SubScopes="SubScopes",
IndexTensors="IndexTensors")
def create_sub_net(self):
truenet = core.Net.create()
scale_op_t = Operator("scale", X='X', Y='Out', scale=2.)
scale_op_t = Operator("scale", X='X', Out='Out', scale=2.)
truenet.append_op(scale_op_t)
truenet.complete_add_op(True)
self.condop.set_truenet(truenet)
falsenet = core.Net.create()
scale_op_t = Operator("scale", X='X', Y='Out', scale=-2.)
scale_op_t = Operator("scale", X='X', Out='Out', scale=-2.)
falsenet.append_op(scale_op_t)
falsenet.complete_add_op(True)
self.condop.set_falsenet(falsenet)
def test_forward(self):
print 'test cond op forward'
pd_output = self.forward()
py_output = self.py_cond.forward()
print 'pd_output', pd_output
print
print 'py_output', py_output
self.assertEqual(pd_output.shape, py_output.shape)
print 'test passed'
return 0
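# Editor's note: the assertion above only compares shapes. Assuming the
# scatter merge preserves row order, a value-level check could be added:
#     self.assertTrue(np.allclose(pd_output, py_output))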
if __name__ == "__main__":
unittest.main()