From 6e7666f199ab1849e37c4f2e1e2570316dcf5c04 Mon Sep 17 00:00:00 2001 From: Yang Yang Date: Sun, 8 Oct 2017 05:36:19 +0000 Subject: [PATCH] before backward --- paddle/framework/CMakeLists.txt | 2 +- paddle/framework/executor_test.cc | 14 +++++++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index d8812d77433..7dc9d5c804a 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -44,7 +44,7 @@ add_custom_command(TARGET framework_py_proto POST_BUILD cc_library(backward SRCS backward.cc DEPS net_op) cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context) -cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto ${GLOB_OP_LIB}) +cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward ${GLOB_OP_LIB}) if(WITH_GPU) nv_test(executor_test SRCS executor_test.cc DEPS executor) else() diff --git a/paddle/framework/executor_test.cc b/paddle/framework/executor_test.cc index e8ea09b77dc..7ce472ed2f3 100644 --- a/paddle/framework/executor_test.cc +++ b/paddle/framework/executor_test.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include #include "gtest/gtest.h" #include "paddle/framework/attribute.h" +#include "paddle/framework/backward.h" #include "paddle/framework/block_desc.h" #include "paddle/framework/grad_op_builder.h" #include "paddle/framework/op_desc.h" @@ -27,6 +28,7 @@ USE_OP(gaussian_random); USE_OP(feed); USE_OP(fetch); USE_OP(mul); +USE_OP(squared_l2_distance); using std::string; using namespace paddle::platform; @@ -170,10 +172,16 @@ class ExecutorTesterRandom : public ::testing::Test { root_block); AddOp("mul", {{"X", {"b"}}, {"Y", {"w2"}}}, {{"Out", {"a_out"}}}, {}, root_block); - - AddOp("fetch", {{"Input", {"a_out"}}}, {}, - {{"dims", std::vector{input_dim, batch_size}}, {"col", 1}}, + AddOp("squared_l2_distance", {{"X", {"a"}}, {"Y", {"a_out"}}}, + {{"Out", {"l2_distance"}}, {"sub_result", {"l2_distance_sub"}}}, {}, root_block); + + AppendBackward(pdesc_, {}); + // AddOp("fetch", {{"Input", {"sub_result"}}}, {}, + // {{"dims", std::vector{input_dim, batch_size}}, {"col", 0}}, + // root_block); + AddOp("fetch", {{"Input", {"l2_distance"}}}, {}, + {{"dims", std::vector{batch_size}}, {"col", 1}}, root_block); } protected: -- GitLab