提交 39d79e64 编写于 作者: Z zchen0211

modified codes

上级 299dcb67
...@@ -78,9 +78,6 @@ class Tensor { ...@@ -78,9 +78,6 @@ class Tensor {
template <typename T> template <typename T>
inline T* mutable_data(DDim dims, platform::Place place); inline T* mutable_data(DDim dims, platform::Place place);
/*! Size of a single element in data() */
inline size_t element_size() { return holder_->element_size(); }
/*! Return the dimensions of the memory block. */ /*! Return the dimensions of the memory block. */
inline const DDim& dims() const; inline const DDim& dims() const;
...@@ -132,7 +129,6 @@ class Tensor { ...@@ -132,7 +129,6 @@ class Tensor {
virtual ~Placeholder() {} virtual ~Placeholder() {}
virtual void* ptr() const = 0; virtual void* ptr() const = 0;
virtual size_t size() const = 0; virtual size_t size() const = 0;
virtual size_t element_size() const = 0;
virtual std::type_index type() const = 0; virtual std::type_index type() const = 0;
virtual platform::Place place() const = 0; virtual platform::Place place() const = 0;
}; };
...@@ -143,8 +139,7 @@ class Tensor { ...@@ -143,8 +139,7 @@ class Tensor {
: ptr_(static_cast<T*>(memory::Alloc(place, size)), : ptr_(static_cast<T*>(memory::Alloc(place, size)),
memory::PODDeleter<T, Place>(place)), memory::PODDeleter<T, Place>(place)),
place_(place), place_(place),
size_(size), size_(size) {
element_size_(sizeof(T)) {
PADDLE_ENFORCE_NOT_NULL(ptr_, "Insufficient %s memory to allocation.", PADDLE_ENFORCE_NOT_NULL(ptr_, "Insufficient %s memory to allocation.",
(is_cpu_place(place_) ? "CPU" : "GPU")); (is_cpu_place(place_) ? "CPU" : "GPU"));
} }
...@@ -153,7 +148,6 @@ class Tensor { ...@@ -153,7 +148,6 @@ class Tensor {
virtual platform::Place place() const { return place_; } virtual platform::Place place() const { return place_; }
virtual void* ptr() const { return static_cast<void*>(ptr_.get()); } virtual void* ptr() const { return static_cast<void*>(ptr_.get()); }
virtual std::type_index type() const { return std::type_index(typeid(T)); } virtual std::type_index type() const { return std::type_index(typeid(T)); }
virtual size_t element_size() const { return element_size_; }
/*! the pointer of memory block. */ /*! the pointer of memory block. */
std::unique_ptr<T, memory::PODDeleter<T, Place>> ptr_; std::unique_ptr<T, memory::PODDeleter<T, Place>> ptr_;
...@@ -163,9 +157,6 @@ class Tensor { ...@@ -163,9 +157,6 @@ class Tensor {
/*! the size of memory block. */ /*! the size of memory block. */
size_t size_; size_t size_;
/*! the size of a single element */
size_t element_size_;
}; };
/*! holds the memory block if allocated. */ /*! holds the memory block if allocated. */
......
...@@ -36,7 +36,7 @@ TEST(Tensor, DataAssert) { ...@@ -36,7 +36,7 @@ TEST(Tensor, DataAssert) {
} catch (paddle::platform::EnforceNotMet err) { } catch (paddle::platform::EnforceNotMet err) {
caught = true; caught = true;
std::string msg = std::string msg =
"holder_ should not be null\nTenosr holds no memory. Call " "holder_ should not be null\nTensor holds no memory. Call "
"Tensor::mutable_data first."; "Tensor::mutable_data first.";
const char* what = err.what(); const char* what = err.what();
for (size_t i = 0; i < msg.length(); ++i) { for (size_t i = 0; i < msg.length(); ++i) {
...@@ -59,8 +59,6 @@ TEST(Tensor, MutableData) { ...@@ -59,8 +59,6 @@ TEST(Tensor, MutableData) {
// initialization // initialization
p1 = src_tensor.mutable_data<float>(make_ddim({1, 2, 3}), CPUPlace()); p1 = src_tensor.mutable_data<float>(make_ddim({1, 2, 3}), CPUPlace());
EXPECT_NE(p1, nullptr); EXPECT_NE(p1, nullptr);
// check tensor type
EXPECT_EQ(src_tensor.element_size(), sizeof(float));
// set src_tensor a new dim with large size // set src_tensor a new dim with large size
// momery is supposed to be re-allocated // momery is supposed to be re-allocated
p2 = src_tensor.mutable_data<float>(make_ddim({3, 4}), CPUPlace()); p2 = src_tensor.mutable_data<float>(make_ddim({3, 4}), CPUPlace());
...@@ -114,7 +112,7 @@ TEST(Tensor, ShareDataWith) { ...@@ -114,7 +112,7 @@ TEST(Tensor, ShareDataWith) {
} catch (paddle::platform::EnforceNotMet err) { } catch (paddle::platform::EnforceNotMet err) {
caught = true; caught = true;
std::string msg = std::string msg =
"holder_ should not be null\nTenosr holds no memory. Call " "holder_ should not be null\nTensor holds no memory. Call "
"Tensor::mutable_data first."; "Tensor::mutable_data first.";
const char* what = err.what(); const char* what = err.what();
for (size_t i = 0; i < msg.length(); ++i) { for (size_t i = 0; i < msg.length(); ++i) {
...@@ -276,4 +274,4 @@ TEST(Tensor, ReshapeToMatrix) { ...@@ -276,4 +274,4 @@ TEST(Tensor, ReshapeToMatrix) {
Tensor res = ReshapeToMatrix<int>(src, 2); Tensor res = ReshapeToMatrix<int>(src, 2);
ASSERT_EQ(res.dims()[0], 2 * 3); ASSERT_EQ(res.dims()[0], 2 * 3);
ASSERT_EQ(res.dims()[1], 4 * 9); ASSERT_EQ(res.dims()[1], 4 * 9);
} }
\ No newline at end of file
...@@ -80,8 +80,7 @@ endfunction() ...@@ -80,8 +80,7 @@ endfunction()
add_subdirectory(math) add_subdirectory(math)
set(DEPS_OPS set(DEPS_OPS
recurrent_op) recurrent_op
set(DEPS_OPS
cond_op) cond_op)
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS framework_proto tensor net_op) DEPS framework_proto tensor net_op)
......
...@@ -28,6 +28,7 @@ namespace operators { ...@@ -28,6 +28,7 @@ namespace operators {
using Scope = framework::Scope; using Scope = framework::Scope;
using Variable = framework::Variable; using Variable = framework::Variable;
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DDim = framework::DDim; using DDim = framework::DDim;
void CondOp::CreateScope(const Scope& scope) const { void CondOp::CreateScope(const Scope& scope) const {
...@@ -41,8 +42,9 @@ void CondOp::CreateScope(const Scope& scope) const { ...@@ -41,8 +42,9 @@ void CondOp::CreateScope(const Scope& scope) const {
void CondOp::CreateIndexTensor(const Scope& scope) const { void CondOp::CreateIndexTensor(const Scope& scope) const {
auto index_tensors_var = scope.FindVar("IndexTensors"); auto index_tensors_var = scope.FindVar("IndexTensors");
PADDLE_ENFORCE(index_tensors_var != nullptr, ""); PADDLE_ENFORCE(index_tensors_var != nullptr, "");
auto& index_tensors = *index_tensors_var->GetMutable<std::vector<Tensor>>(); auto& index_tensors =
index_tensors.push_back(Tensor()); *index_tensors_var->GetMutable<std::vector<LoDTensor>>();
index_tensors.push_back(LoDTensor());
} }
void CondOp::InferShape(const Scope& scope) const { void CondOp::InferShape(const Scope& scope) const {
...@@ -65,8 +67,8 @@ void CondOp::InferShape(const Scope& scope) const { ...@@ -65,8 +67,8 @@ void CondOp::InferShape(const Scope& scope) const {
for (auto& input : Inputs("Xs")) { for (auto& input : Inputs("Xs")) {
// Create a new tensor in sub-scope for input-type tensor // Create a new tensor in sub-scope for input-type tensor
Variable* v = sub_scopes[i]->NewVar(input); Variable* v = sub_scopes[i]->NewVar(input);
Tensor* sub_input = v->GetMutable<Tensor>(); LoDTensor* sub_input = v->GetMutable<LoDTensor>();
sub_input->Resize(scope.FindVar(input)->GetMutable<Tensor>()->dims()); sub_input->Resize(scope.FindVar(input)->GetMutable<LoDTensor>()->dims());
} }
for (auto& output : (*sub_net_op_[i]).Outputs()) { for (auto& output : (*sub_net_op_[i]).Outputs()) {
...@@ -80,33 +82,40 @@ void CondOp::InferShape(const Scope& scope) const { ...@@ -80,33 +82,40 @@ void CondOp::InferShape(const Scope& scope) const {
} }
for (auto& output : Outputs("Outs")) { for (auto& output : Outputs("Outs")) {
Tensor* tensor_t_out = sub_scopes[0]->FindVar(output)->GetMutable<Tensor>(); LoDTensor* tensor_t_out =
PADDLE_ENFORCE_NOT_NULL(tensor_t_out, "True output should be NULL"); sub_scopes[0]->FindVar(output)->GetMutable<LoDTensor>();
Tensor* tensor_f_out = sub_scopes[1]->FindVar(output)->GetMutable<Tensor>(); PADDLE_ENFORCE_NOT_NULL(tensor_t_out, "True output should not be NULL");
PADDLE_ENFORCE_NOT_NULL(tensor_f_out, "True output should be NULL"); LoDTensor* tensor_f_out =
sub_scopes[1]->FindVar(output)->GetMutable<LoDTensor>();
PADDLE_ENFORCE_NOT_NULL(tensor_f_out, "False output should not be NULL");
auto* tensor_out_var = scope.FindVar(output); auto* tensor_out_var = scope.FindVar(output);
PADDLE_ENFORCE_NOT_NULL(tensor_out_var, "Output not found"); PADDLE_ENFORCE_NOT_NULL(tensor_out_var, "Output not found");
Tensor* tensor_out = tensor_out_var->GetMutable<Tensor>(); LoDTensor* tensor_out = tensor_out_var->GetMutable<LoDTensor>();
PADDLE_ENFORCE_NOT_NULL(tensor_t_out, "True output should be NULL"); PADDLE_ENFORCE_NOT_NULL(tensor_t_out,
"True output tensor should not be NULL");
// check output size should be same // check output size should be same
PADDLE_ENFORCE_EQ(tensor_t_out->dims(), tensor_f_out->dims(), PADDLE_ENFORCE_EQ(tensor_t_out->dims(), tensor_f_out->dims(),
"Outputs not of the same shape"); "Outputs not of the same shape");
tensor_out->Resize(tensor_t_out->dims()); tensor_out->Resize(tensor_t_out->dims());
tensor_out->mutable_data<float>(tensor_out->dims(), platform::CPUPlace()); // tensor_out->mutable_data<float>(tensor_out->dims(),
// platform::CPUPlace());
tensor_out->mutable_data<float>(platform::CPUPlace());
} }
} }
void CondOp::Run(const Scope& scope, void CondOp::Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const { const platform::DeviceContext& dev_ctx) const {
auto sub_scopes = scope.FindVar("SubScopes")->Get<std::vector<Scope*>>(); auto* sub_scopes_var = scope.FindVar("SubScopes");
auto index_tensors = auto sub_scopes = sub_scopes_var->Get<std::vector<Scope*>>();
scope.FindVar("IndexTensors")->Get<std::vector<Tensor>>(); auto* index_tensors_var = scope.FindVar("IndexTensors");
auto index_tensors = index_tensors_var->Get<std::vector<LoDTensor>>();
std::string cond_name = Input("Cond"); std::string cond_name = Input("Cond");
Variable* cond_var = scope.FindVar(cond_name); Variable* cond_var = scope.FindVar(cond_name);
PADDLE_ENFORCE_NOT_NULL(cond_var); PADDLE_ENFORCE_NOT_NULL(cond_var);
const Tensor* cond = cond_var->GetMutable<Tensor>(); const LoDTensor* cond = cond_var->GetMutable<LoDTensor>();
// Step 1: get the true/false index at runtime // Step 1: get the true/false index at runtime
// index_[0]: vector<int>, contains all index for cond[i] == true // index_[0]: vector<int>, contains all index for cond[i] == true
...@@ -139,11 +148,11 @@ void CondOp::Run(const Scope& scope, ...@@ -139,11 +148,11 @@ void CondOp::Run(const Scope& scope,
// find Tensor // find Tensor
Variable* v = scope.FindVar(input); Variable* v = scope.FindVar(input);
PADDLE_ENFORCE_NOT_NULL(v); PADDLE_ENFORCE_NOT_NULL(v);
Tensor* tensor_parent = v->GetMutable<Tensor>(); LoDTensor* tensor_parent = v->GetMutable<LoDTensor>();
v = sub_scopes[i]->FindVar(input); v = sub_scopes[i]->FindVar(input);
PADDLE_ENFORCE_NOT_NULL(v); PADDLE_ENFORCE_NOT_NULL(v);
Tensor* tensor_child = v->GetMutable<Tensor>(); LoDTensor* tensor_child = v->GetMutable<LoDTensor>();
// Resize child // Resize child
DDim dim = tensor_child->dims(); DDim dim = tensor_child->dims();
...@@ -157,7 +166,9 @@ void CondOp::Run(const Scope& scope, ...@@ -157,7 +166,9 @@ void CondOp::Run(const Scope& scope,
} }
// Step 3: run // Step 3: run
for (int i = 0; i < 2; ++i) sub_net_op_[i]->Run(*sub_scopes[i], dev_ctx); for (int i = 0; i < 2; ++i) {
sub_net_op_[i]->Run(*sub_scopes[i], dev_ctx);
}
// Step 4: merge output results // Step 4: merge output results
for (int i = 0; i < 2; ++i) { for (int i = 0; i < 2; ++i) {
...@@ -166,11 +177,11 @@ void CondOp::Run(const Scope& scope, ...@@ -166,11 +177,11 @@ void CondOp::Run(const Scope& scope,
// find Tensor // find Tensor
Variable* v = scope.FindVar(output); Variable* v = scope.FindVar(output);
PADDLE_ENFORCE_NOT_NULL(v); PADDLE_ENFORCE_NOT_NULL(v);
Tensor* tensor_parent = v->GetMutable<Tensor>(); LoDTensor* tensor_parent = v->GetMutable<LoDTensor>();
v = sub_scopes[i]->FindVar(output); v = sub_scopes[i]->FindVar(output);
PADDLE_ENFORCE_NOT_NULL(v); PADDLE_ENFORCE_NOT_NULL(v);
Tensor* tensor_child = v->GetMutable<Tensor>(); LoDTensor* tensor_child = v->GetMutable<LoDTensor>();
ScatterUpdate<float>(dev_ctx.GetPlace(), tensor_child, &index_tensors[i], ScatterUpdate<float>(dev_ctx.GetPlace(), tensor_child, &index_tensors[i],
tensor_parent); tensor_parent);
...@@ -192,7 +203,9 @@ class CondOpProtoAndCheckerMaker : public framework::OpProtoAndCheckerMaker { ...@@ -192,7 +203,9 @@ class CondOpProtoAndCheckerMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC( AddComment(R"DOC(
Sample dependent Cond Operator: Sample dependent Cond Operator:
The equation is: Out[i] = subnet_t[i], if Cond[i] == true Given Cond[i] as a 1/0 vector to indicate true/false
The equation is:
Out[i] = subnet_t[i], if Cond[i] == true
Out[i] = subnet_t[i], if Cond[i] == false Out[i] = subnet_t[i], if Cond[i] == false
)DOC"); )DOC");
} }
......
...@@ -24,6 +24,17 @@ limitations under the License. */ ...@@ -24,6 +24,17 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
/*
* @brief CondOp is a dynamic if-else Operator
*
* It has a input tensor named cond indicating which netop each instance will
* run.
*
* if cond == 1, it will run true_net, which is a NetOp.
*
* if cond == 0, it will run false_net, which is another NetOp.
*/
class CondOp : public framework::OperatorBase { class CondOp : public framework::OperatorBase {
public: public:
CondOp(const std::string& type, const framework::VariableNameMap& inputs, CondOp(const std::string& type, const framework::VariableNameMap& inputs,
...@@ -45,18 +56,18 @@ class CondOp : public framework::OperatorBase { ...@@ -45,18 +56,18 @@ class CondOp : public framework::OperatorBase {
void CreateIndexTensor(const framework::Scope& scope) const; void CreateIndexTensor(const framework::Scope& scope) const;
/** /*
* InferShape must be called before Run. * InferShape must be called before Run.
*/ */
void InferShape(const framework::Scope& scope) const override; void InferShape(const framework::Scope& scope) const override;
// Set True Block // Set True Block
void set_truenet(std::unique_ptr<OperatorBase> net) { void set_truenet(std::unique_ptr<OperatorBase>&& net) {
sub_net_op_[0] = std::move(net); sub_net_op_[0] = std::move(net);
} }
// Set False Block // Set False Block
void set_falsenet(std::unique_ptr<OperatorBase> net) { void set_falsenet(std::unique_ptr<OperatorBase>&& net) {
sub_net_op_[1] = std::move(net); sub_net_op_[1] = std::move(net);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册