diff --git a/test/dm/test-main.c b/test/dm/test-main.c
index 10d2706377fafc70419eaaff39133b26191897aa..88ef267458e57d556a32fa56114ae2f73cfdeecf 100644
--- a/test/dm/test-main.c
+++ b/test/dm/test-main.c
@@ -22,7 +22,7 @@ struct unit_test_state global_dm_test_state;
 static struct dm_test_state _global_priv_dm_test_state;
 
 /* Get ready for testing */
-static int dm_test_init(struct unit_test_state *uts)
+static int dm_test_init(struct unit_test_state *uts, bool of_live)
 {
 	struct dm_test_state *dms = uts->priv;
 
@@ -31,7 +31,11 @@ static int dm_test_init(struct unit_test_state *uts)
 	memset(dm_testdrv_op_count, '\0', sizeof(dm_testdrv_op_count));
 	state_reset_for_test(state_get_current());
 
-	ut_assertok(dm_init(false));
+#ifdef CONFIG_OF_LIVE
+	/* Determine whether to make the live tree available */
+	gd->of_root = of_live ? uts->of_root : NULL;
+#endif
+	ut_assertok(dm_init(of_live));
 	dms->root = dm_root();
 
 	return 0;
@@ -72,13 +76,15 @@ static int dm_test_destroy(struct unit_test_state *uts)
 	return 0;
 }
 
-static int dm_do_test(struct unit_test_state *uts, struct unit_test *test)
+static int dm_do_test(struct unit_test_state *uts, struct unit_test *test,
+		      bool of_live)
 {
 	struct sandbox_state *state = state_get_current();
 	const char *fname = strrchr(test->file, '/') + 1;
 
-	printf("Test: %s: %s\n", test->name, fname);
-	ut_assertok(dm_test_init(uts));
+	printf("Test: %s: %s%s\n", test->name, fname,
+	       !of_live ? " (flat tree)" : "");
+	ut_assertok(dm_test_init(uts, of_live));
 
 	uts->start = mallinfo();
 	if (test->flags & DM_TESTF_SCAN_PDATA)
@@ -109,10 +115,10 @@ static int dm_test_main(const char *test_name)
 	struct unit_test *tests = ll_entry_start(struct unit_test, dm_test);
 	const int n_ents = ll_entry_count(struct unit_test, dm_test);
 	struct unit_test_state *uts = &global_dm_test_state;
-	uts->priv = &_global_priv_dm_test_state;
 	struct unit_test *test;
 	int run_count;
 
+	uts->priv = &_global_priv_dm_test_state;
 	uts->fail_count = 0;
 
 	/*
@@ -129,6 +135,9 @@ static int dm_test_main(const char *test_name)
 		printf("Running %d driver model tests\n", n_ents);
 
 	run_count = 0;
+#ifdef CONFIG_OF_LIVE
+	uts->of_root = gd->of_root;
+#endif
 	for (test = tests; test < tests + n_ents; test++) {
 		const char *name = test->name;
 
@@ -137,7 +146,7 @@ static int dm_test_main(const char *test_name)
 			name += 8;
 		if (test_name && strcmp(test_name, name))
 			continue;
-		ut_assertok(dm_do_test(uts, test));
+		ut_assertok(dm_do_test(uts, test, false));
 		run_count++;
 	}