diff --git a/drivers/core/device-remove.c b/drivers/core/device-remove.c
index b80bf52320e2ded743fe202939b06b11d209486f..cc0043b990b77400f29374f06b48ca07e87c8ada 100644
--- a/drivers/core/device-remove.c
+++ b/drivers/core/device-remove.c
@@ -174,7 +174,13 @@ int device_remove(struct udevice *dev, uint flags)
 	if (ret)
 		goto err;
 
-	if (drv->remove) {
+	/*
+	 * Remove the device if called with the "normal" remove flag set,
+	 * or if the remove flag matches any of the drivers remove flags
+	 */
+	if (drv->remove &&
+	    ((flags & DM_REMOVE_NORMAL) ||
+	     (flags & (drv->flags & DM_FLAG_ACTIVE_DMA)))) {
 		ret = drv->remove(dev);
 		if (ret)
 			goto err_remove;
@@ -188,10 +194,13 @@ int device_remove(struct udevice *dev, uint flags)
 		}
 	}
 
-	device_free(dev);
+	if ((flags & DM_REMOVE_NORMAL) ||
+	    (flags & (drv->flags & DM_FLAG_ACTIVE_DMA))) {
+		device_free(dev);
 
-	dev->seq = -1;
-	dev->flags &= ~DM_FLAG_ACTIVATED;
+		dev->seq = -1;
+		dev->flags &= ~DM_FLAG_ACTIVATED;
+	}
 
 	return ret;
 
diff --git a/drivers/core/root.c b/drivers/core/root.c
index d8c51fb496f10d2c6f7a2412093160b286da86d1..42679d047cfa0522988576fc10b360c04e972ce3 100644
--- a/drivers/core/root.c
+++ b/drivers/core/root.c
@@ -184,6 +184,15 @@ int dm_uninit(void)
 	return 0;
 }
 
+#if CONFIG_IS_ENABLED(DM_DEVICE_REMOVE)
+int dm_remove_devices_flags(uint flags)
+{
+	device_remove(dm_root(), flags);
+
+	return 0;
+}
+#endif
+
 int dm_scan_platdata(bool pre_reloc_only)
 {
 	int ret;
diff --git a/include/dm/root.h b/include/dm/root.h
index 3cf730dcee1cb50e9aa5ba3376ae48fc97603cdb..058eb9892314ebf3275997d16221e25d3e55b4c1 100644
--- a/include/dm/root.h
+++ b/include/dm/root.h
@@ -115,4 +115,20 @@ int dm_init(void);
  */
 int dm_uninit(void);
 
+#if CONFIG_IS_ENABLED(DM_DEVICE_REMOVE)
+/**
+ * dm_remove_devices_flags - Call remove function of all drivers with
+ *                           specific removal flags set to selectively
+ *                           remove drivers
+ *
+ * All devices with the matching flags set will be removed
+ *
+ * @flags: Flags for selective device removal
+ * @return 0 if OK, -ve on error
+ */
+int dm_remove_devices_flags(uint flags);
+#else
+static inline int dm_remove_devices_flags(uint flags) { return 0; }
+#endif
+
 #endif