863. All Nodes Distance K in Binary Tree

https://leetcode.com/problems/all-nodes-distance-k-in-binary-tree/

最直接的想法是先把树重建成无向图(父子之间双向连边),然后从 target 开始做 BFS,走 k 层后队列里剩下的节点就是答案:

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
class Solution {
    /**
     * Returns the values of all nodes exactly {@code k} edges away from {@code target}.
     *
     * <p>The tree is first converted into an undirected adjacency map (parent-to-child and
     * child-to-parent edges), then a level-order BFS is run from {@code target} for {@code k}
     * levels; the nodes remaining in the frontier are exactly those at distance {@code k}.
     *
     * @param root   root of the tree; {@code null} yields an empty list
     * @param target node of the tree to measure distances from
     * @param k      required distance in edges; a negative value yields an empty list
     * @return values of the nodes at distance {@code k} in BFS order (a fresh list on every call)
     */
    public List<Integer> distanceK(TreeNode root, TreeNode target, int k) {
        // Local rather than a field: a field would accumulate answers across repeated calls.
        List<Integer> ans = new ArrayList<>();
        if (root == null || target == null || k < 0) {
            return ans;
        }
        Map<TreeNode, List<TreeNode>> graph = buildGraph(root);

        Queue<TreeNode> frontier = new ArrayDeque<>();
        Set<TreeNode> visited = new HashSet<>();
        frontier.offer(target);
        visited.add(target);
        // Mark nodes visited when they are enqueued, so each node enters the queue at most once.
        for (int depth = 0; depth < k && !frontier.isEmpty(); depth++) {
            for (int remaining = frontier.size(); remaining > 0; remaining--) {
                List<TreeNode> neighbors = graph.get(frontier.poll());
                if (neighbors == null) {
                    continue; // target is not part of the tree rooted at root
                }
                for (TreeNode next : neighbors) {
                    if (visited.add(next)) {
                        frontier.offer(next);
                    }
                }
            }
        }
        // After k expansions, the frontier holds exactly the nodes at distance k.
        for (TreeNode node : frontier) {
            ans.add(node.val);
        }
        return ans;
    }

    /**
     * Builds an undirected adjacency map covering every node under {@code root}.
     * Each node maps to its children plus its parent, if it has one.
     */
    private Map<TreeNode, List<TreeNode>> buildGraph(TreeNode root) {
        Map<TreeNode, List<TreeNode>> graph = new HashMap<>();
        Queue<TreeNode> queue = new ArrayDeque<>();
        graph.put(root, new ArrayList<>());
        queue.offer(root);
        while (!queue.isEmpty()) {
            TreeNode node = queue.poll();
            for (TreeNode child : new TreeNode[] {node.left, node.right}) {
                if (child == null) {
                    continue;
                }
                graph.get(node).add(child);
                // A child is first seen here, so its list starts with the edge back to its parent.
                List<TreeNode> childEdges = new ArrayList<>();
                childEdges.add(node);
                graph.put(child, childEdges);
                queue.offer(child);
            }
        }
        return graph;
    }
}

或者做 DFS:先复制一棵带 parent 指针的树,再从 target 出发向左、右、父三个方向搜索。注意这种把树看成(无向)图的思想,在需要"往上走"(访问父节点)的时候很有用。

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
class Solution {
    /**
     * Copy of a TreeNode that also records its parent, so the tree can be walked upward
     * as well as downward. Static: it needs no reference to the enclosing Solution.
     */
    private static final class Node {
        final int val;
        final Node parent;
        Node left;
        Node right;

        Node(int val, Node parent) {
            this.val = val;
            this.parent = parent;
        }
    }

    /** Copy of the caller's target, located by clone(); reset at the start of every call. */
    private Node targetCopy;

    /**
     * Returns the values of all nodes exactly {@code k} edges away from {@code target}.
     *
     * @param root   root of the tree
     * @param target node of the tree to measure distances from
     * @param k      required distance in edges; a negative value yields an empty list
     * @return values of the nodes at distance {@code k} in DFS order (a fresh list on every call)
     */
    public List<Integer> distanceK(TreeNode root, TreeNode target, int k) {
        // All result/visited state is per call, so the same Solution can be reused safely.
        List<Integer> ans = new ArrayList<>();
        if (k < 0) {
            return ans;
        }
        if (k == 0) {
            ans.add(target.val);
            return ans;
        }
        targetCopy = null;
        clone(root, null, target);
        dfs(targetCopy, k, ans, new HashSet<>());
        return ans;
    }

    /**
     * Walks the parent-linked tree as an undirected graph, collecting nodes {@code k} steps away.
     * In a tree every node has a unique simple path from the start, so the first visit sees its
     * true distance. Recursion stops at distance k, because nodes further out cannot contribute.
     */
    private void dfs(Node node, int k, List<Integer> ans, Set<Node> visited) {
        if (node == null || !visited.add(node)) {
            return;
        }
        if (k == 0) {
            ans.add(node.val);
            return;
        }
        dfs(node.left, k - 1, ans, visited);
        dfs(node.right, k - 1, ans, visited);
        dfs(node.parent, k - 1, ans, visited);
    }

    /**
     * Recursively copies {@code node}'s subtree into parent-linked Nodes.
     * Records the copy of {@code target} in {@link #targetCopy} when it is encountered.
     */
    private Node clone(TreeNode node, Node parent, TreeNode target) {
        if (node == null) {
            return null;
        }
        Node copy = new Node(node.val, parent);
        if (node == target) {
            targetCopy = copy;
        }
        copy.left = clone(node.left, copy, target);
        copy.right = clone(node.right, copy, target);
        return copy;
    }
}

Last updated