diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index d8812d7743381a1c2e4d76d22a234680b34b2850..7dc9d5c804a348e310989b0aca0683ce383c5447 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -44,7 +44,7 @@ add_custom_command(TARGET framework_py_proto POST_BUILD cc_library(backward SRCS backward.cc DEPS net_op) cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context) -cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto ${GLOB_OP_LIB}) +cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward ${GLOB_OP_LIB}) if(WITH_GPU) nv_test(executor_test SRCS executor_test.cc DEPS executor) else() diff --git a/paddle/framework/executor_test.cc b/paddle/framework/executor_test.cc index e8ea09b77dc4f18841340cee09da61bc2095302d..7ce472ed2f3e2c725e79f4084177b562c1d1646c 100644 --- a/paddle/framework/executor_test.cc +++ b/paddle/framework/executor_test.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include #include "gtest/gtest.h" #include "paddle/framework/attribute.h" +#include "paddle/framework/backward.h" #include "paddle/framework/block_desc.h" #include "paddle/framework/grad_op_builder.h" #include "paddle/framework/op_desc.h" @@ -27,6 +28,7 @@ USE_OP(gaussian_random); USE_OP(feed); USE_OP(fetch); USE_OP(mul); +USE_OP(squared_l2_distance); using std::string; using namespace paddle::platform; @@ -170,10 +172,16 @@ class ExecutorTesterRandom : public ::testing::Test { root_block); AddOp("mul", {{"X", {"b"}}, {"Y", {"w2"}}}, {{"Out", {"a_out"}}}, {}, root_block); - - AddOp("fetch", {{"Input", {"a_out"}}}, {}, - {{"dims", std::vector{input_dim, batch_size}}, {"col", 1}}, + AddOp("squared_l2_distance", {{"X", {"a"}}, {"Y", {"a_out"}}}, + {{"Out", {"l2_distance"}}, {"sub_result", {"l2_distance_sub"}}}, {}, root_block); + + AppendBackward(pdesc_, {}); + // AddOp("fetch", {{"Input", {"sub_result"}}}, {}, + // {{"dims", std::vector{input_dim, batch_size}}, {"col", 0}}, + // root_block); + AddOp("fetch", {{"Input", {"l2_distance"}}}, {}, + {{"dims", std::vector{batch_size}}, {"col", 1}}, root_block); } protected: