Java树结构

前言

很久没写业务接口了，最近刚好遇到需要把平铺数据构建成树结构的场景，在此做个简单总结。

逻辑实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
package com.easyliao.demo;

import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.Data;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
* @author cuishiying
* @date 2021-01-22
*/
public class TreeTest {

    /** pid value that marks a top-level node in this example. */
    private static final Integer ROOT_PID = 0;

    /**
     * Flat node: {@code pid} holds the {@code id} of the parent node.
     * Getters/setters are generated by Lombok {@code @Data}.
     */
    @Data
    public static class Node {
        private Integer id;
        private String name;
        private Integer pid;
        private List<Node> children;

        public Node(Integer id, String name, Integer pid) {
            this.id = id;
            this.name = name;
            this.pid = pid;
        }
    }

    public static void main(String[] args) throws Exception {
        Node p0 = new Node(1, "p00", 0);
        Node p1 = new Node(2, "p01", 0);
        Node p2 = new Node(3, "p10", 1);
        Node p3 = new Node(4, "p11", 1);
        Node p4 = new Node(5, "p20", 3);
        List<Node> nodes = Arrays.asList(p0, p1, p2, p3, p4);
        List<Node> tree = buildTree(nodes);

        ObjectMapper objectMapper = new ObjectMapper();
        System.out.println(objectMapper.writeValueAsString(tree));
    }

    /**
     * Builds the tree and returns its single top-level node.
     *
     * @param pidList flat node list; every pid must be non-null
     * @return the first node whose pid is 0, with its subtree attached
     * @throws IllegalStateException if no node has pid 0
     */
    public Node buildOneNodeTree(List<Node> pidList) {
        List<Node> roots = buildTree(pidList);
        if (roots.isEmpty()) {
            throw new IllegalStateException("no top-level node with pid " + ROOT_PID);
        }
        return roots.get(0);
    }

    /**
     * Builds a forest from a flat list and returns all top-level nodes.
     *
     * <p>Every node in {@code pidList} has its children list replaced by the nodes whose pid equals
     * its id. Leaf nodes receive an empty mutable list rather than {@code null}.
     *
     * @param pidList flat node list; every pid must be non-null
     *                ({@code Collectors.groupingBy} rejects null keys)
     * @return the nodes whose pid is 0 (empty list, never {@code null}, if there are none)
     */
    public static List<Node> buildTree(List<Node> pidList) {
        // pid -> direct children
        Map<Integer, List<Node>> childrenByPid =
                pidList.stream().collect(Collectors.groupingBy(Node::getPid));
        // The grouped lists hold the same Node instances as pidList, so attaching children here
        // also wires up every subtree reachable from the roots.
        pidList.forEach(node -> node.setChildren(
                childrenByPid.getOrDefault(node.getId(), new java.util.ArrayList<>())));
        return childrenByPid.getOrDefault(ROOT_PID, new java.util.ArrayList<>());
    }
}

简单抽象

这里我们把上面的逻辑抽象成一个简单的通用工具类。

  1. 首先抽象出一个接口，所有需要构建成树结构的对象都要实现该接口
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
package cn.idea360.assistant.dev.tree;

import java.util.List;

/**
* @author cuishiying
* @date 2021-01-22
*/
public interface TreeNode<T extends TreeNode<T, ID>, ID> {

    /** @return the identifier of this node */
    ID getId();

    /** @return the identifier of this node's parent */
    ID getPid();

    /** @return the direct children currently attached to this node */
    List<T> getChildren();

    /**
     * Replaces the direct children of this node.
     *
     * @param children the new list of direct children
     */
    void setChildren(List<T> children);
}
  2. 定义基本实现类
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
package cn.idea360.assistant.dev.tree;

import lombok.Data;

import java.util.ArrayList;
import java.util.List;

/**
* @author cuishiying
* @date 2021-01-22
*/
@Data
public class Node implements TreeNode<Node, Long> {

    /** Identifier of this node. */
    private Long id;

    /** Identifier of the parent node. */
    private Long pid;

    /**
     * Direct children; starts out as an empty mutable list (assigned in the constructor).
     * NOTE(review): presumably marked transient to keep it out of Lombok's equals/hashCode — confirm.
     */
    private transient List<Node> children;

    /**
     * Creates a node with no children.
     *
     * @param id  this node's identifier
     * @param pid the parent's identifier
     */
    public Node(Long id, Long pid) {
        this.id = id;
        this.pid = pid;
        this.children = new ArrayList<>();
    }

}
  3. 工具类
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
package cn.idea360.assistant.dev.tree;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
* @author cuishiying
* @date 2021-01-22
*/
public class TreeBuilder {

    private TreeBuilder() {
    }

    /**
     * Builds a forest from a flat node list and returns its top-level nodes.
     *
     * @param nodes flat node list
     * @param <T>   entity implementing {@link TreeNode}
     * @param <ID>  id type
     * @return top-level nodes in input order, each with its subtree attached
     */
    public static <T extends TreeNode<T, ID>, ID> List<T> buildTree(List<T> nodes) {
        return new ArrayList<>(buildTreeMap(nodes).values());
    }

    /**
     * Links every node to its parent and returns the top-level nodes keyed by id.
     *
     * <p>A node is top-level when its pid is {@code null} or matches no id in {@code nodes}.
     * Every other node is appended to its parent's children list. A parent whose children list is
     * {@code null} first gets a fresh {@link ArrayList}. Existing children are kept, so calling
     * this twice on the same nodes appends them again.
     *
     * @param nodes flat node list
     * @param <T>   entity implementing {@link TreeNode}
     * @param <ID>  id type
     * @return top-level nodes keyed by id, iterating in input order
     */
    public static <T extends TreeNode<T, ID>, ID> Map<ID, T> buildTreeMap(List<T> nodes) {
        // id -> node; a later node with a duplicate id replaces the earlier one as lookup target
        Map<ID, T> nodeMap = new HashMap<>();
        for (T node : nodes) {
            nodeMap.put(node.getId(), node);
        }

        // LinkedHashMap keeps the roots in input order (HashMap iteration order is unspecified)
        Map<ID, T> roots = new LinkedHashMap<>();
        for (T node : nodes) {
            ID pid = node.getPid();
            T parent = pid == null ? null : nodeMap.get(pid);
            if (parent == null) {
                roots.put(node.getId(), node);
                continue;
            }
            // getTreeNodes (or an implementation without a default) may leave children null
            if (parent.getChildren() == null) {
                parent.setChildren(new ArrayList<>());
            }
            parent.getChildren().add(node);
        }
        return roots;
    }

    /**
     * Returns the direct children of {@code pid}, with every node's children rebuilt from pids.
     *
     * <p>Side effect: every node in {@code nodes} has its children list replaced by the nodes whose
     * pid equals its id. Leaf nodes get an empty mutable list.
     *
     * @param nodes flat node list; {@code TreeNode::getPid} must not return {@code null}
     *              ({@code Collectors.groupingBy} rejects null keys)
     * @param pid   id of the node whose children are requested
     * @param <T>   entity implementing {@link TreeNode}
     * @param <ID>  id type
     * @return direct children of {@code pid} with their subtrees (empty list if none, never null)
     */
    public static <T extends TreeNode<T, ID>, ID> List<T> getTreeNodes(List<T> nodes, ID pid) {
        Map<ID, List<T>> pidMap = nodes.stream().collect(Collectors.groupingBy(TreeNode::getPid));
        nodes.forEach(node -> node.setChildren(pidMap.getOrDefault(node.getId(), new ArrayList<>())));
        return pidMap.getOrDefault(pid, new ArrayList<>());
    }

    /**
     * Traces the ancestor chain of a node.
     *
     * @param nodes flat node list; ids must be unique ({@code Collectors.toMap} throws
     *              {@link IllegalStateException} on duplicate keys)
     * @param id    id of the node to start from
     * @param <T>   entity implementing {@link TreeNode}
     * @param <ID>  id type
     * @return the chain ordered from the topmost ancestor down to the node itself
     *         (empty if {@code id} is unknown)
     */
    public static <T extends TreeNode<T, ID>, ID> List<T> getParentNodes(List<T> nodes, ID id) {
        Map<ID, T> nodeMap = nodes.stream().collect(Collectors.toMap(TreeNode::getId, Function.identity()));
        List<T> result = new ArrayList<>();
        // visited guards against pid cycles (e.g. a node that is its own parent), which would loop forever
        Set<ID> visited = new HashSet<>();
        T current = nodeMap.get(id);
        while (current != null && visited.add(current.getId())) {
            result.add(current);
            current = nodeMap.get(current.getPid());
        }
        // collected bottom-up; reverse to order by level, top first
        Collections.reverse(result);
        return result;
    }
}
  4. 测试
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
public class TreeTest {

    private final ObjectMapper objectMapper = new ObjectMapper();

    private final List<Node> nodes = new ArrayList<>();

    /** Fixture: 1 is the root, 2 and 3 hang under 1, 4 under 2 and 5 under 3. */
    @Before
    public void init() {
        long[][] idAndPid = {{1L, 0L}, {2L, 1L}, {3L, 1L}, {4L, 2L}, {5L, 3L}};
        for (long[] pair : idAndPid) {
            nodes.add(new Node(pair[0], pair[1]));
        }
    }

    @Test
    public void t1() {
        // build the whole tree
        printJson("tree结构:\n", buildTree(nodes));
    }

    @Test
    public void t2() {
        // ancestor chain of node 5
        printJson("节点链路:\n", getParentNodes(nodes, 5L));
    }

    @Test
    public void t3() {
        // direct children of node 2
        printJson("叶子节点:\n", getTreeNodes(nodes, 2L));
    }

    /** Prints {@code title} followed by the pretty-printed JSON of {@code value}. */
    @SneakyThrows
    private void printJson(String title, Object value) {
        String json = objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(value);
        System.out.println(title + json);
    }
}

输出

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
tree结构:
[ {
"id" : 1,
"pid" : 0,
"children" : [ {
"id" : 2,
"pid" : 1,
"children" : [ {
"id" : 4,
"pid" : 2,
"children" : [ ]
} ]
}, {
"id" : 3,
"pid" : 1,
"children" : [ {
"id" : 5,
"pid" : 3,
"children" : [ ]
} ]
} ]
} ]
节点链路:
[ {
"id" : 1,
"pid" : 0,
"children" : [ ]
}, {
"id" : 3,
"pid" : 1,
"children" : [ ]
}, {
"id" : 5,
"pid" : 3,
"children" : [ ]
} ]
叶子节点:
[ {
"id" : 4,
"pid" : 2,
"children" : null
} ]

最后

欢迎大家关注公众号【当我遇上你】支持我。