JAVA-多线程3

分類 编程语言, Java

Java语言虽然内置了多线程支持,启动一个新线程非常方便,但是,创建线程需要操作系统资源(线程资源,栈空间等),频繁创建和销毁大量线程需要消耗大量时间。

如果可以复用一组线程:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
┌─────┐ execute  ┌──────────────────┐
│Task1│─────────▶│ThreadPool │
├─────┤ │┌───────┐┌───────┐│
│Task2│ ││Thread1││Thread2││
├─────┤ │└───────┘└───────┘│
│Task3│ │┌───────┐┌───────┐│
├─────┤ ││Thread3││Thread4││
│Task4│ │└───────┘└───────┘│
├─────┤ └──────────────────┘
│Task5│
├─────┤
│Task6│
└─────┘
...

那么我们就可以把很多小任务让一组线程来执行,而不是一个任务对应一个新线程。这种能接收大量小任务并进行分发处理的就是线程池。

简单地说,线程池内部维护了若干个线程,没有任务的时候,这些线程都处于等待状态。如果有新任务,就分配一个空闲线程执行。如果所有线程都处于忙碌状态,新任务要么放入队列等待,要么增加一个新线程进行处理。

Java标准库提供了ExecutorService接口表示线程池,它的典型用法如下:

1
2
3
4
5
6
7
8
// 创建固定大小的线程池:
ExecutorService executor = Executors.newFixedThreadPool(3);
// 提交任务:
executor.submit(task1);
executor.submit(task2);
executor.submit(task3);
executor.submit(task4);
executor.submit(task5);

因为ExecutorService只是接口,Java标准库提供的几个常用实现类有:

  • FixedThreadPool:线程数固定的线程池;
  • CachedThreadPool:线程数根据任务动态调整的线程池;
  • SingleThreadExecutor:仅单线程执行的线程池。

创建这些线程池的方法都被封装到Executors这个类中。我们以FixedThreadPool为例,看看线程池的执行逻辑:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
// thread-pool
import java.util.concurrent.*;

public class Main {
public static void main(String[] args) {
// 创建一个固定大小的线程池:
ExecutorService es = Executors.newFixedThreadPool(4);
for (int i = 0; i < 6; i++) {
es.submit(new Task("" + i));
}
// 关闭线程池:
es.shutdown();
}
}

class Task implements Runnable {
private final String name;

public Task(String name) {
this.name = name;
}

@Override
public void run() {
System.out.println("start task " + name);
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
}
System.out.println("end task " + name);
}
}

我们观察执行结果,一次性放入6个任务,由于线程池只有固定的4个线程,因此,前4个任务会同时执行,等到有线程空闲后,才会执行后面的两个任务。

线程池在程序结束的时候要关闭。使用shutdown()方法关闭线程池的时候,它会等待正在执行的任务先完成,然后再关闭。shutdownNow()会立刻停止正在执行的任务,awaitTermination()则会等待指定的时间让线程池关闭。

如果我们把线程池改为CachedThreadPool,由于这个线程池的实现会根据任务数量动态调整线程池的大小,所以6个任务可一次性全部同时执行。

如果我们想把线程池的大小限制在4~10个之间动态调整怎么办?我们查看Executors.newCachedThreadPool()方法的源码:

1
2
3
4
5
6
public static ExecutorService newCachedThreadPool() {
return new ThreadPoolExecutor(
0, Integer.MAX_VALUE,
60L, TimeUnit.SECONDS,
new SynchronousQueue<Runnable>());
}

因此,想创建指定动态范围的线程池,可以这么写:

1
2
3
4
5
6
int min = 4;
int max = 10;
ExecutorService es = new ThreadPoolExecutor(
min, max,
60L, TimeUnit.SECONDS,
new SynchronousQueue<Runnable>());

ScheduledThreadPool

还有一种任务,需要定期反复执行,例如,每秒刷新证券价格。这种任务本身固定,需要反复执行的,可以使用ScheduledThreadPool。放入ScheduledThreadPool的任务可以定期反复执行。

创建一个ScheduledThreadPool仍然是通过Executors类:

1
ScheduledExecutorService ses = Executors.newScheduledThreadPool(4);

我们可以提交一次性任务,它会在指定延迟后只执行一次:

1
2
// 1秒后执行一次性任务:
ses.schedule(new Task("one-time"), 1, TimeUnit.SECONDS);

如果任务以固定的每3秒执行,我们可以这样写:

1
2
// 2秒后开始执行定时任务,每3秒执行:
ses.scheduleAtFixedRate(new Task("fixed-rate"), 2, 3, TimeUnit.SECONDS);

如果任务以固定的3秒为间隔执行,我们可以这样写:

1
2
// 2秒后开始执行定时任务,以3秒为间隔执行:
ses.scheduleWithFixedDelay(new Task("fixed-delay"), 2, 3, TimeUnit.SECONDS);

注意FixedRate和FixedDelay的区别。FixedRate是指任务总是以固定时间间隔触发,不管任务执行多长时间:

1
2
3
│░░░░   │░░░░░░ │░░░    │░░░░░  │░░░  
├───────┼───────┼───────┼───────┼────▶
│◀─────▶│◀─────▶│◀─────▶│◀─────▶│

而FixedDelay是指,上一次任务执行完毕后,等待固定的时间间隔,再执行下一次任务:

1
2
3
│░░░│       │░░░░░│       │░░│       │░
└───┼───────┼─────┼───────┼──┼───────┼──▶
│◀─────▶│ │◀─────▶│ │◀─────▶│

因此,使用ScheduledThreadPool时,我们要根据需要选择执行一次、FixedRate执行还是FixedDelay执行。

细心的童鞋还可以思考下面的问题:

  • 在FixedRate模式下,假设每秒触发,如果某次任务执行时间超过1秒,后续任务会不会并发执行?
  • 如果任务抛出了异常,后续任务是否继续执行?

Java标准库还提供了一个java.util.Timer类,这个类也可以定期执行任务,但是,一个Timer会对应一个Thread,所以,一个Timer只能定期执行一个任务,多个定时任务必须启动多个Timer,而一个ScheduledThreadPool就可以调度多个定时任务,所以,我们完全可以用ScheduledThreadPool取代旧的Timer

练习

使用线程池复用线程。

下载练习

小结

JDK提供了ExecutorService实现了线程池功能:

  • 线程池内部维护一组线程,可以高效执行大量小任务;
  • Executors提供了静态方法创建不同类型的ExecutorService
  • 必须调用shutdown()关闭ExecutorService
  • ScheduledThreadPool可以定期调度多个任务。

使用Future

在执行多个任务的时候,使用Java标准库提供的线程池是非常方便的。我们提交的任务只需要实现Runnable接口,就可以让线程池去执行:

1
2
3
4
5
6
7
class Task implements Runnable {
public String result;

public void run() {
this.result = longTimeCalculation();
}
}

Runnable接口有个问题,它的方法没有返回值。如果任务需要一个返回结果,那么只能保存到变量,还要提供额外的方法读取,非常不便。所以,Java标准库还提供了一个Callable接口,和Runnable接口比,它多了一个返回值:

1
2
3
4
5
class Task implements Callable<String> {
public String call() throws Exception {
return longTimeCalculation();
}
}

并且Callable接口是一个泛型接口,可以返回指定类型的结果。

现在的问题是,如何获得异步执行的结果?

如果仔细看ExecutorService.submit()方法,可以看到,它返回了一个Future类型,一个Future类型的实例代表一个未来能获取结果的对象:

1
2
3
4
5
6
7
ExecutorService executor = Executors.newFixedThreadPool(4); 
// 定义任务:
Callable<String> task = new Task();
// 提交任务并获得Future:
Future<String> future = executor.submit(task);
// 从Future获取异步执行返回的结果:
String result = future.get(); // 可能阻塞

当我们提交一个Callable任务后,我们会同时获得一个Future对象,然后,我们在主线程某个时刻调用Future对象的get()方法,就可以获得异步执行的结果。在调用get()时,如果异步任务已经完成,我们就直接获得结果。如果异步任务还没有完成,那么get()会阻塞,直到任务完成后才返回结果。

一个Future<V>接口表示一个未来可能会返回的结果,它定义的方法有:

  • get():获取结果(可能会等待)
  • get(long timeout, TimeUnit unit):获取结果,但只等待指定的时间;
  • cancel(boolean mayInterruptIfRunning):取消当前任务;
  • isDone():判断任务是否已完成。

练习

使用Future获取异步执行结果。

下载练习

小结

对线程池提交一个Callable任务,可以获得一个Future对象;

可以用Future在将来某个时刻获取结果。



使用Future获得异步执行结果时,要么调用阻塞方法get(),要么轮询看isDone()是否为true,这两种方法都不是很好,因为主线程也会被迫等待。

从Java 8开始引入了CompletableFuture,它针对Future做了改进,可以传入回调对象,当异步任务完成或者发生异常时,自动调用回调对象的回调方法。

我们以获取股票价格为例,看看如何使用CompletableFuture

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
// CompletableFuture
import java.util.concurrent.CompletableFuture;

public class Main {
public static void main(String[] args) throws Exception {
// 创建异步执行任务:
CompletableFuture<Double> cf = CompletableFuture.supplyAsync(Main::fetchPrice);
// 如果执行成功:
cf.thenAccept((result) -> {
System.out.println("price: " + result);
});
// 如果执行异常:
cf.exceptionally((e) -> {
e.printStackTrace();
return null;
});
// 主线程不要立刻结束,否则CompletableFuture默认使用的线程池会立刻关闭:
Thread.sleep(200);
}

static Double fetchPrice() {
try {
Thread.sleep(100);
} catch (InterruptedException e) {
}
if (Math.random() < 0.3) {
throw new RuntimeException("fetch price failed!");
}
return 5 + Math.random() * 20;
}
}

创建一个CompletableFuture是通过CompletableFuture.supplyAsync()实现的,它需要一个实现了Supplier接口的对象:

1
2
3
public interface Supplier<T> {
T get();
}

这里我们用lambda语法简化了一下,直接传入Main::fetchPrice,因为Main.fetchPrice()静态方法的签名符合Supplier接口的定义(除了方法名外)。

紧接着,CompletableFuture已经被提交给默认的线程池执行了,我们需要定义的是CompletableFuture完成时和异常时需要回调的实例。完成时,CompletableFuture会调用Consumer对象:

1
2
3
public interface Consumer<T> {
void accept(T t);
}

异常时,CompletableFuture会调用Function对象:

1
2
3
public interface Function<T, R> {
R apply(T t);
}

这里我们都用lambda语法简化了代码。

可见CompletableFuture的优点是:

  • 异步任务结束时,会自动回调某个对象的方法;
  • 异步任务出错时,会自动回调某个对象的方法;
  • 主线程设置好回调后,不再关心异步任务的执行。

如果只是实现了异步回调机制,我们还看不出CompletableFuture相比Future的优势。CompletableFuture更强大的功能是,多个CompletableFuture可以串行执行,例如,定义两个CompletableFuture,第一个CompletableFuture根据证券名称查询证券代码,第二个CompletableFuture根据证券代码查询证券价格,这两个CompletableFuture实现串行操作如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
// CompletableFuture
import java.util.concurrent.CompletableFuture;

public class Main {
public static void main(String[] args) throws Exception {
// 第一个任务:
CompletableFuture<String> cfQuery = CompletableFuture.supplyAsync(() -> {
return queryCode("中国石油");
});
// cfQuery成功后继续执行下一个任务:
CompletableFuture<Double> cfFetch = cfQuery.thenApplyAsync((code) -> {
return fetchPrice(code);
});
// cfFetch成功后打印结果:
cfFetch.thenAccept((result) -> {
System.out.println("price: " + result);
});
// 主线程不要立刻结束,否则CompletableFuture默认使用的线程池会立刻关闭:
Thread.sleep(2000);
}

static String queryCode(String name) {
try {
Thread.sleep(100);
} catch (InterruptedException e) {
}
return "601857";
}

static Double fetchPrice(String code) {
try {
Thread.sleep(100);
} catch (InterruptedException e) {
}
return 5 + Math.random() * 20;
}
}

除了串行执行外,多个CompletableFuture还可以并行执行。例如,我们考虑这样的场景:

同时从新浪和网易查询证券代码,只要任意一个返回结果,就进行下一步查询价格,查询价格也同时从新浪和网易查询,只要任意一个返回结果,就完成操作:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
// CompletableFuture
import java.util.concurrent.CompletableFuture;

public class Main {
public static void main(String[] args) throws Exception {
// 两个CompletableFuture执行异步查询:
CompletableFuture<String> cfQueryFromSina = CompletableFuture.supplyAsync(() -> {
return queryCode("中国石油", "https://finance.sina.com.cn/code/");
});
CompletableFuture<String> cfQueryFrom163 = CompletableFuture.supplyAsync(() -> {
return queryCode("中国石油", "https://money.163.com/code/");
});

// 用anyOf合并为一个新的CompletableFuture:
CompletableFuture<Object> cfQuery = CompletableFuture.anyOf(cfQueryFromSina, cfQueryFrom163);

// 两个CompletableFuture执行异步查询:
CompletableFuture<Double> cfFetchFromSina = cfQuery.thenApplyAsync((code) -> {
return fetchPrice((String) code, "https://finance.sina.com.cn/price/");
});
CompletableFuture<Double> cfFetchFrom163 = cfQuery.thenApplyAsync((code) -> {
return fetchPrice((String) code, "https://money.163.com/price/");
});

// 用anyOf合并为一个新的CompletableFuture:
CompletableFuture<Object> cfFetch = CompletableFuture.anyOf(cfFetchFromSina, cfFetchFrom163);

// 最终结果:
cfFetch.thenAccept((result) -> {
System.out.println("price: " + result);
});
// 主线程不要立刻结束,否则CompletableFuture默认使用的线程池会立刻关闭:
Thread.sleep(200);
}

static String queryCode(String name, String url) {
System.out.println("query code from " + url + "...");
try {
Thread.sleep((long) (Math.random() * 100));
} catch (InterruptedException e) {
}
return "601857";
}

static Double fetchPrice(String code, String url) {
System.out.println("query price from " + url + "...");
try {
Thread.sleep((long) (Math.random() * 100));
} catch (InterruptedException e) {
}
return 5 + Math.random() * 20;
}
}

上述逻辑实现的异步查询规则实际上是:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
┌─────────────┐ ┌─────────────┐
│ Query Code │ │ Query Code │
│ from sina │ │ from 163 │
└─────────────┘ └─────────────┘
│ │
└───────┬───────┘

┌─────────────┐
│ anyOf │
└─────────────┘

┌───────┴────────┐
▼ ▼
┌─────────────┐ ┌─────────────┐
│ Query Price │ │ Query Price │
│ from sina │ │ from 163 │
└─────────────┘ └─────────────┘
│ │
└────────┬───────┘

┌─────────────┐
│ anyOf │
└─────────────┘


┌─────────────┐
│Display Price│
└─────────────┘

除了anyOf()可以实现“任意个CompletableFuture只要一个成功”,allOf()可以实现“所有CompletableFuture都必须成功”,这些组合操作可以实现非常复杂的异步流程控制。

最后我们注意CompletableFuture的命名规则:

  • xxx():表示该方法将继续在已有的线程中执行;
  • xxxAsync():表示将异步在线程池中执行。

练习

使用CompletableFuture。

下载练习

小结

CompletableFuture可以指定异步处理流程:

  • thenAccept()处理正常结果;
  • exceptional()处理异常结果;
  • thenApplyAsync()用于串行化另一个CompletableFuture
  • anyOf()allOf()用于并行化多个CompletableFuture

Java 7开始引入了一种新的Fork/Join线程池,它可以执行一种特殊的任务:把一个大任务拆成多个小任务并行执行。

我们举个例子:如果要计算一个超大数组的和,最简单的做法是用一个循环在一个线程内完成:

1
2
┌─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┐
└─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┘

还有一种方法,可以把数组拆成两部分,分别计算,最后加起来就是最终结果,这样可以用两个线程并行执行:

1
2
3
4
┌─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┐
└─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┘
┌─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┐
└─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┘

如果拆成两部分还是很大,我们还可以继续拆,用4个线程并行执行:

1
2
3
4
5
6
7
8
┌─┬─┬─┬─┬─┬─┐
└─┴─┴─┴─┴─┴─┘
┌─┬─┬─┬─┬─┬─┐
└─┴─┴─┴─┴─┴─┘
┌─┬─┬─┬─┬─┬─┐
└─┴─┴─┴─┴─┴─┘
┌─┬─┬─┬─┬─┬─┐
└─┴─┴─┴─┴─┴─┘

这就是Fork/Join任务的原理:判断一个任务是否足够小,如果是,直接计算,否则,就分拆成几个小任务分别计算。这个过程可以反复“裂变”成一系列小任务。

我们来看如何使用Fork/Join对大数据进行并行求和:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import java.util.Random;
import java.util.concurrent.*;

public class Main {
public static void main(String[] args) throws Exception {
// 创建2000个随机数组成的数组:
long[] array = new long[2000];
long expectedSum = 0;
for (int i = 0; i < array.length; i++) {
array[i] = random();
expectedSum += array[i];
}
System.out.println("Expected sum: " + expectedSum);
// fork/join:
ForkJoinTask<Long> task = new SumTask(array, 0, array.length);
long startTime = System.currentTimeMillis();
Long result = ForkJoinPool.commonPool().invoke(task);
long endTime = System.currentTimeMillis();
System.out.println("Fork/join sum: " + result + " in " + (endTime - startTime) + " ms.");
}

static Random random = new Random(0);

static long random() {
return random.nextInt(10000);
}
}

class SumTask extends RecursiveTask<Long> {
static final int THRESHOLD = 500;
long[] array;
int start;
int end;

SumTask(long[] array, int start, int end) {
this.array = array;
this.start = start;
this.end = end;
}

@Override
protected Long compute() {
if (end - start <= THRESHOLD) {
// 如果任务足够小,直接计算:
long sum = 0;
for (int i = start; i < end; i++) {
sum += this.array[i];
// 故意放慢计算速度:
try {
Thread.sleep(1);
} catch (InterruptedException e) {
}
}
return sum;
}
// 任务太大,一分为二:
int middle = (end + start) / 2;
System.out.println(String.format("split %d~%d ==> %d~%d, %d~%d", start, end, start, middle, middle, end));
SumTask subtask1 = new SumTask(this.array, start, middle);
SumTask subtask2 = new SumTask(this.array, middle, end);
invokeAll(subtask1, subtask2);
Long subresult1 = subtask1.join();
Long subresult2 = subtask2.join();
Long result = subresult1 + subresult2;
System.out.println("result = " + subresult1 + " + " + subresult2 + " ==> " + result);
return result;
}
}

观察上述代码的执行过程,一个大的计算任务0~2000首先分裂为两个小任务0~1000和1000~2000,这两个小任务仍然太大,继续分裂为更小的0~500,500~1000,1000~1500,1500~2000,最后,计算结果被依次合并,得到最终结果。

因此,核心代码SumTask继承自RecursiveTask,在compute()方法中,关键是如何“分裂”出子任务并且提交子任务:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class SumTask extends RecursiveTask<Long> {
protected Long compute() {
// “分裂”子任务:
SumTask subtask1 = new SumTask(...);
SumTask subtask2 = new SumTask(...);
// invokeAll会并行运行两个子任务:
invokeAll(subtask1, subtask2);
// 获得子任务的结果:
Long subresult1 = subtask1.join();
Long subresult2 = subtask2.join();
// 汇总结果:
return subresult1 + subresult2;
}
}

Fork/Join线程池在Java标准库中就有应用。Java标准库提供的java.util.Arrays.parallelSort(array)可以进行并行排序,它的原理就是内部通过Fork/Join对大数组分拆进行并行排序,在多核CPU上就可以大大提高排序的速度。

练习

使用Fork/Join。

下载练习

小结

Fork/Join是一种基于“分治”的算法:通过分解任务,并行执行,最后合并结果得到最终结果。

ForkJoinPool线程池可以把一个大任务分拆成小任务并行执行,任务类必须继承自RecursiveTaskRecursiveAction

使用Fork/Join模式可以进行并行计算以提高效率。

多线程是Java实现多任务的基础,Thread对象代表一个线程,我们可以在代码中调用Thread.currentThread()获取当前线程。例如,打印日志时,可以同时打印出当前线程的名字:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
// Thread
public class Main {
public static void main(String[] args) throws Exception {
log("start main...");
new Thread(() -> {
log("run task...");
}).start();
new Thread(() -> {
log("print...");
}).start();
log("end main.");
}

static void log(String s) {
System.out.println(Thread.currentThread().getName() + ": " + s);
}
}

对于多任务,Java标准库提供的线程池可以方便地执行这些任务,同时复用线程。Web应用程序就是典型的多任务应用,每个用户请求页面时,我们都会创建一个任务,类似:

1
2
3
4
5
6
public void process(User user) {
checkPermission();
doWork();
saveStatus();
sendResponse();
}

然后,通过线程池去执行这些任务。

观察process()方法,它内部需要调用若干其他方法,同时,我们遇到一个问题:如何在一个线程内传递状态?

process()方法需要传递的状态就是User实例。有的童鞋会想,简单地传入User就可以了:

1
2
3
4
5
6
public void process(User user) {
checkPermission(user);
doWork(user);
saveStatus(user);
sendResponse(user);
}

但是往往一个方法又会调用其他很多方法,这样会导致User传递到所有地方:

1
2
3
4
5
6
void doWork(User user) {
queryStatus(user);
checkStatus();
setNewStatus(user);
log();
}

这种在一个线程中,横跨若干方法调用,需要传递的对象,我们通常称之为上下文(Context),它是一种状态,可以是用户身份、任务信息等。

给每个方法增加一个context参数非常麻烦,而且有些时候,如果调用链有无法修改源码的第三方库,User对象就传不进去了。

Java标准库提供了一个特殊的ThreadLocal,它可以在一个线程中传递同一个对象。

ThreadLocal实例通常总是以静态字段初始化如下:

1
static ThreadLocal<User> threadLocalUser = new ThreadLocal<>();

它的典型使用方式如下:

1
2
3
4
5
6
7
8
9
10
void processUser(user) {
try {
threadLocalUser.set(user);
step1();
step2();
log();
} finally {
threadLocalUser.remove();
}
}

通过设置一个User实例关联到ThreadLocal中,在移除之前,所有方法都可以随时获取到该User实例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
void step1() {
User u = threadLocalUser.get();
log();
printUser();
}

void step2() {
User u = threadLocalUser.get();
checkUser(u.id);
}

void log() {
User u = threadLocalUser.get();
println(u.name);
}

注意到普通的方法调用一定是同一个线程执行的,所以,step1()step2()以及log()方法内,threadLocalUser.get()获取的User对象是同一个实例。

实际上,可以把ThreadLocal看成一个全局Map<Thread, Object>:每个线程获取ThreadLocal变量时,总是使用Thread自身作为key:

1
Object threadLocalValue = threadLocalMap.get(Thread.currentThread());

因此,ThreadLocal相当于给每个线程都开辟了一个独立的存储空间,各个线程的ThreadLocal关联的实例互不干扰。

最后,特别注意ThreadLocal一定要在finally中清除:

1
2
3
4
5
6
try {
threadLocalUser.set(user);
...
} finally {
threadLocalUser.remove();
}

这是因为当前线程执行完相关代码后,很可能会被重新放入线程池中,如果ThreadLocal没有被清除,该线程执行其他代码时,会把上一次的状态带进去。

为了保证能释放ThreadLocal关联的实例,我们可以通过AutoCloseable接口配合try (resource) {...}结构,让编译器自动为我们关闭。例如,一个保存了当前用户名的ThreadLocal可以封装为一个UserContext对象:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
public class UserContext implements AutoCloseable {

static final ThreadLocal<String> ctx = new ThreadLocal<>();

public UserContext(String user) {
ctx.set(user);
}

public static String currentUser() {
return ctx.get();
}

@Override
public void close() {
ctx.remove();
}
}

使用的时候,我们借助try (resource) {...}结构,可以这么写:

1
2
3
4
try (var ctx = new UserContext("Bob")) {
// 可任意调用UserContext.currentUser():
String currentUser = UserContext.currentUser();
} // 在此自动调用UserContext.close()方法释放ThreadLocal关联对象

这样就在UserContext中完全封装了ThreadLocal,外部代码在try (resource) {...}内部可以随时调用UserContext.currentUser()获取当前线程绑定的用户名。

练习

练习使用ThreadLocal。

下载练习

小结

ThreadLocal表示线程的“局部变量”,它确保每个线程的ThreadLocal变量都是各自独立的;

ThreadLocal适合在一个线程的处理流程中保持上下文(避免了同一参数在所有方法中传递);

使用ThreadLocal要用try ... finally结构,并在finally中清除。

虚拟线程(Virtual Thread)是Java 19引入的一种轻量级线程,它在很多其他语言中被称为协程、纤程、绿色线程、用户态线程等。

在理解虚拟线程前,我们先回顾一下线程的特点:

  • 线程是由操作系统创建并调度的资源;
  • 线程切换会耗费大量CPU时间;
  • 一个系统能同时调度的线程数量是有限的,通常在几百至几千级别。

因此,我们说线程是一种重量级资源。在服务器端,对用户请求,通常都实现为一个线程处理一个请求。由于用户的请求数往往远超操作系统能同时调度的线程数量,所以通常使用线程池来尽量减少频繁创建和销毁线程的成本。

对于需要处理大量IO请求的任务来说,使用线程是低效的,因为一旦读写IO,线程就必须进入等待状态,直到IO数据返回。常见的IO操作包括:

  • 读写文件;
  • 读写网络,例如HTTP请求;
  • 读写数据库,本质上是通过JDBC实现网络调用。

我们举个例子,一个处理HTTP请求的线程,它在读写网络、文件的时候就会进入等待状态:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
Begin
────────
Blocking ──▶ Read HTTP Request
Wait...
Wait...
Wait...
────────
Running
────────
Blocking ──▶ Read Config File
Wait...
────────
Running
────────
Blocking ──▶ Read Database
Wait...
Wait...
Wait...
────────
Running
────────
Blocking ──▶ Send HTTP Response
Wait...
Wait...
────────
End

真正由CPU执行的代码消耗的时间非常少,线程的大部分时间都在等待IO。我们把这类任务称为IO密集型任务。

为了能高效执行IO密集型任务,Java从19开始引入了虚拟线程。虚拟线程的接口和普通线程是一样的,但是执行方式不一样。虚拟线程不是由操作系统调度,而是由普通线程调度,即成百上千个虚拟线程可以由一个普通线程调度。任何时刻,只能执行一个虚拟线程,但是,一旦该虚拟线程执行一个IO操作进入等待时,它会被立刻“挂起”,然后执行下一个虚拟线程。什么时候IO数据返回了,这个挂起的虚拟线程才会被再次调度。因此,若干个虚拟线程可以在一个普通线程中交替运行:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
Begin
───────────
V1 Runing
V1 Blocking ──▶ Read HTTP Request
───────────
V2 Runing
V2 Blocking ──▶ Read HTTP Request
───────────
V3 Runing
V3 Blocking ──▶ Read HTTP Request
───────────
V1 Runing
V1 Blocking ──▶ Read Config File
───────────
V2 Runing
V2 Blocking ──▶ Read Database
───────────
V1 Runing
V1 Blocking ──▶ Read Database
───────────
V3 Runing
V3 Blocking ──▶ Read Database
───────────
V2 Runing
V2 Blocking ──▶ Send HTTP Response
───────────
V1 Runing
V1 Blocking ──▶ Send HTTP Response
───────────
V3 Runing
V3 Blocking ──▶ Send HTTP Response
───────────
End

如果我们单独看一个虚拟线程的代码,在一个方法中:

1
2
3
4
5
6
7
8
9
10
void register() {
config = readConfigFile("./config.json"); // #1
if (config.useFullName) {
name = req.firstName + " " + req.lastName;
}
insertInto(db, name); // #2
if (config.cache) {
redis.set(key, name); // #3
}
}

涉及到IO读写的#1、#2、#3处,执行到这些地方的时候(进入相关的JNI方法内部时)会自动挂起,并切换到其他虚拟线程执行。等到数据返回后,当前虚拟线程会再次调度并执行,因此,代码看起来是同步执行,但实际上是异步执行的。

使用虚拟线程

虚拟线程的接口和普通线程一样,唯一区别在于创建虚拟线程只能通过特定方法。

方法一:直接创建虚拟线程并运行:

1
2
3
4
5
6
// 传入Runnable实例并立刻运行:
Thread vt = Thread.startVirtualThread(() -> {
System.out.println("Start virtual thread...");
Thread.sleep(10);
System.out.println("End virtual thread.");
});

方法二:创建虚拟线程但不自动运行,而是手动调用start()开始运行:

1
2
3
4
5
6
7
8
// 创建VirtualThread:
Thread.ofVirtual().unstarted(() -> {
System.out.println("Start virtual thread...");
Thread.sleep(1000);
System.out.println("End virtual thread.");
});
// 运行:
vt.start();

方法三:通过虚拟线程的ThreadFactory创建虚拟线程,然后手动调用start()开始运行:

1
2
3
4
5
6
7
8
9
10
// 创建ThreadFactory:
ThreadFactory tf = Thread.ofVirtual().factory();
// 创建VirtualThread:
Thread vt = tf.newThread(() -> {
System.out.println("Start virtual thread...");
Thread.sleep(1000);
System.out.println("End virtual thread.");
});
// 运行:
vt.start();

直接调用start()实际上是由ForkJoinPool的线程来调度的。我们也可以自己创建调度线程,然后运行虚拟线程:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// 创建调度器:
ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor();
// 创建大量虚拟线程并调度:
ThreadFactory tf = Thread.ofVirtual().factory();
for (int i=0; i<100000; i++) {
Thread vt = tf.newThread(() -> { ... });
executor.submit(vt);
// 也可以直接传入Runnable或Callable:
executor.submit(() -> {
System.out.println("Start virtual thread...");
Thread.sleep(1000);
System.out.println("End virtual thread.");
return true;
});
}

由于虚拟线程属于非常轻量级的资源,因此,用时创建,用完就扔,不要池化虚拟线程。

最后注意,虚拟线程在Java 21正式发布,在Java 19/20是预览功能,默认关闭,需要添加参数--enable-preview启用:

1
java --source 19 --enable-preview Main.java

使用限制

注意到只有以虚拟线程方式运行的代码,才会在执行IO操作时自动被挂起并切换到其他虚拟线程。普通线程的IO操作仍然会等待,例如,我们在main()方法中读写文件,是不会有调度和自动挂起的。

可以自动引发调度切换的操作包括:

  • 文件IO;
  • 网络IO;
  • 使用Concurrent库引发等待;
  • Thread.sleep()操作。

这是因为JDK为了实现虚拟线程,已经对底层相关操作进行了修改,这样应用层的Java代码无需修改即可使用虚拟线程。无法自动切换的语言需要用户手动调用await来实现异步操作:

1
2
3
4
async function doWork() {
await readFile();
await sendNetworkData();
}

在虚拟线程中,如果绕过JDK的IO接口,直接通过JNI读写文件或网络是无法实现调度的。此外,在synchronized块内部也无法调度。

练习

使用虚拟线程调度10万个任务并观察耗时:

1
2
3
4
5
6
7
8
9
10
11
12
public class Main {
public static void main(String[] args) {
ExecutorService es = Executors.newVirtualThreadPerTaskExecutor();
for (int i=0; i<100000; i++) {
es.submit(() -> {
Thread.sleep(1000);
return 0;
});
}
es.close();
}
}

再将ExecutorService改为线程池模式并对比结果。

下载练习

小结

Java 19引入的虚拟线程是为了解决IO密集型任务的吞吐量,它可以高效通过少数线程去调度大量虚拟线程;

虚拟线程在执行到IO操作或Blocking操作时,会自动切换到其他虚拟线程执行,从而避免当前线程等待,能最大化线程的执行效率;

虚拟线程使用普通线程相同的接口,最大的好处是无需修改任何代码,就可以将现有的IO操作异步化获得更大的吞吐能力。

计算密集型任务不应使用虚拟线程,只能通过增加CPU核心解决,或者利用分布式计算资源。

留言與分享

JAVA-多线程2(线程同步)

分類 编程语言, Java

当多个线程同时运行时,线程的调度由操作系统决定,程序本身无法决定。因此,任何一个线程都有可能在任何指令处被操作系统暂停,然后在某个时间段后继续执行。

这个时候,有个单线程模型下不存在的问题就来了:如果多个线程同时读写共享变量,会出现数据不一致的问题。

我们来看一个例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
// 多线程
public class Main {
public static void main(String[] args) throws Exception {
var add = new AddThread();
var dec = new DecThread();
add.start();
dec.start();
add.join();
dec.join();
System.out.println(Counter.count);
}
}

class Counter {
public static int count = 0;
}

class AddThread extends Thread {
public void run() {
for (int i=0; i<10000; i++) { Counter.count += 1; }
}
}

class DecThread extends Thread {
public void run() {
for (int i=0; i<10000; i++) { Counter.count -= 1; }
}
}

上面的代码很简单,两个线程同时对一个int变量进行操作,一个加10000次,一个减10000次,最后结果应该是0,但是,每次运行,结果实际上都是不一样的。

这是因为对变量进行读取和写入时,结果要正确,必须保证是原子操作。原子操作是指不能被中断的一个或一系列操作。

例如,对于语句:

1
n = n + 1;

看上去是一行语句,实际上对应了3条指令:

1
2
3
ILOAD
IADD
ISTORE

我们假设n的值是100,如果两个线程同时执行n = n + 1,得到的结果很可能不是102,而是101,原因在于:

1
2
3
4
5
6
7
8
9
10
11
┌───────┐     ┌───────┐
│Thread1│ │Thread2│
└───┬───┘ └───┬───┘
│ │
│ILOAD (100) │
│ │ILOAD (100)
│ │IADD
│ │ISTORE (101)
│IADD │
│ISTORE (101) │
▼ ▼

如果线程1在执行ILOAD后被操作系统中断,此刻如果线程2被调度执行,它执行ILOAD后获取的值仍然是100,最终结果被两个线程的ISTORE写入后变成了101,而不是期待的102

这说明多线程模型下,要保证逻辑正确,对共享变量进行读写时,必须保证一组指令以原子方式执行:即某一个线程执行时,其他线程必须等待:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
┌───────┐     ┌───────┐
│Thread1│ │Thread2│
└───┬───┘ └───┬───┘
│ │
│-- lock -- │
│ILOAD (100) │
│IADD │
│ISTORE (101) │
│-- unlock -- │
│ │-- lock --
│ │ILOAD (101)
│ │IADD
│ │ISTORE (102)
│ │-- unlock --
▼ ▼

通过加锁和解锁的操作,就能保证3条指令总是在一个线程执行期间,不会有其他线程会进入此指令区间。即使在执行期线程被操作系统中断执行,其他线程也会因为无法获得锁导致无法进入此指令区间。只有执行线程将锁释放后,其他线程才有机会获得锁并执行。这种加锁和解锁之间的代码块我们称之为临界区(Critical Section),任何时候临界区最多只有一个线程能执行。

可见,保证一段代码的原子性就是通过加锁和解锁实现的。Java程序使用synchronized关键字对一个对象进行加锁:

1
2
3
synchronized(lock) {
n = n + 1;
}

synchronized保证了代码块在任意时刻最多只有一个线程能执行。我们把上面的代码用synchronized改写如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
// 多线程
public class Main {
public static void main(String[] args) throws Exception {
var add = new AddThread();
var dec = new DecThread();
add.start();
dec.start();
add.join();
dec.join();
System.out.println(Counter.count);
}
}

class Counter {
public static final Object lock = new Object();
public static int count = 0;
}

class AddThread extends Thread {
public void run() {
for (int i=0; i<10000; i++) {
synchronized(Counter.lock) {
Counter.count += 1;
}
}
}
}

class DecThread extends Thread {
public void run() {
for (int i=0; i<10000; i++) {
synchronized(Counter.lock) {
Counter.count -= 1;
}
}
}
}

注意到代码:

1
2
3
synchronized(Counter.lock) { // 获取锁
...
} // 释放锁

它表示用Counter.lock实例作为锁,两个线程在执行各自的synchronized(Counter.lock) { ... }代码块时,必须先获得锁,才能进入代码块进行。执行结束后,在synchronized语句块结束会自动释放锁。这样一来,对Counter.count变量进行读写就不可能同时进行。上述代码无论运行多少次,最终结果都是0。

使用synchronized解决了多线程同步访问共享变量的正确性问题。但是,它的缺点是带来了性能下降。因为synchronized代码块无法并发执行。此外,加锁和解锁需要消耗一定的时间,所以,synchronized会降低程序的执行效率。

我们来概括一下如何使用synchronized

  1. 找出修改共享变量的线程代码块;
  2. 选择一个共享实例作为锁;
  3. 使用synchronized(lockObject) { ... }

在使用synchronized的时候,不必担心抛出异常。因为无论是否有异常,都会在synchronized结束处正确释放锁:

1
2
3
4
5
6
7
8
public void add(int m) {
synchronized (obj) {
if (m < 0) {
throw new RuntimeException();
}
this.value += m;
} // 无论有无异常,都会在此释放锁
}

我们再来看一个错误使用synchronized的例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
// 多线程
public class Main {
public static void main(String[] args) throws Exception {
var add = new AddThread();
var dec = new DecThread();
add.start();
dec.start();
add.join();
dec.join();
System.out.println(Counter.count);
}
}

class Counter {
public static final Object lock1 = new Object();
public static final Object lock2 = new Object();
public static int count = 0;
}

class AddThread extends Thread {
public void run() {
for (int i=0; i<10000; i++) {
synchronized(Counter.lock1) {
Counter.count += 1;
}
}
}
}

class DecThread extends Thread {
public void run() {
for (int i=0; i<10000; i++) {
synchronized(Counter.lock2) {
Counter.count -= 1;
}
}
}
}

结果并不是0,这是因为两个线程各自的synchronized锁住的不是同一个对象!这使得两个线程各自都可以同时获得锁:因为JVM只保证同一个锁在任意时刻只能被一个线程获取,但两个不同的锁在同一时刻可以被两个线程分别获取。

因此,使用synchronized的时候,获取到的是哪个锁非常重要。锁对象如果不对,代码逻辑就不对。

我们再看一个例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
// 多线程
public class Main {
public static void main(String[] args) throws Exception {
var ts = new Thread[] { new AddStudentThread(), new DecStudentThread(), new AddTeacherThread(), new DecTeacherThread() };
for (var t : ts) {
t.start();
}
for (var t : ts) {
t.join();
}
System.out.println(Counter.studentCount);
System.out.println(Counter.teacherCount);
}
}

class Counter {
public static final Object lock = new Object();
public static int studentCount = 0;
public static int teacherCount = 0;
}

class AddStudentThread extends Thread {
public void run() {
for (int i=0; i<10000; i++) {
synchronized(Counter.lock) {
Counter.studentCount += 1;
}
}
}
}

class DecStudentThread extends Thread {
public void run() {
for (int i=0; i<10000; i++) {
synchronized(Counter.lock) {
Counter.studentCount -= 1;
}
}
}
}

class AddTeacherThread extends Thread {
public void run() {
for (int i=0; i<10000; i++) {
synchronized(Counter.lock) {
Counter.teacherCount += 1;
}
}
}
}

class DecTeacherThread extends Thread {
public void run() {
for (int i=0; i<10000; i++) {
synchronized(Counter.lock) {
Counter.teacherCount -= 1;
}
}
}
}

上述代码的4个线程对两个共享变量分别进行读写操作,但是使用的锁都是Counter.lock这一个对象,这就造成了原本可以并发执行的Counter.studentCount += 1Counter.teacherCount += 1,现在无法并发执行了,执行效率大大降低。实际上,需要同步的线程可以分成两组:AddStudentThreadDecStudentThreadAddTeacherThreadDecTeacherThread,组之间不存在竞争,因此,应该使用两个不同的锁,即:

AddStudentThreadDecStudentThread使用lockStudent锁:

1
2
3
synchronized(Counter.lockStudent) {
...
}

AddTeacherThreadDecTeacherThread使用lockTeacher锁:

1
2
3
synchronized(Counter.lockTeacher) {
...
}

这样才能最大化地提高执行效率。

不需要synchronized的操作

JVM规范定义了几种原子操作:

  • 基本类型(longdouble除外)赋值,例如:int n = m
  • 引用类型赋值,例如:List<String> list = anotherList

longdouble是64位数据,JVM没有明确规定64位赋值操作是不是一个原子操作,不过在x64平台的JVM是把longdouble的赋值作为原子操作实现的。

单条原子操作的语句不需要同步。例如:

1
2
3
4
5
public void set(int m) {
synchronized(lock) {
this.value = m;
}
}

就不需要同步。

对引用也是类似。例如:

1
2
3
public void set(String s) {
this.value = s;
}

上述赋值语句并不需要同步。

但是,如果是多行赋值语句,就必须保证是同步操作,例如:

1
2
3
4
5
6
7
8
9
10
class Point {
int x;
int y;
public void set(int x, int y) {
synchronized(this) {
this.x = x;
this.y = y;
}
}
}

提示

多线程连续读写多个变量时,同步的目的是为了保证程序逻辑正确!

不但写需要同步,读也需要同步:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class Point {
int x;
int y;

public void set(int x, int y) {
synchronized(this) {
this.x = x;
this.y = y;
}
}

public int[] get() {
int[] copy = new int[2];
copy[0] = x;
copy[1] = y;
}
}

假定当前坐标是(100, 200),那么当设置新坐标为(110, 220)时,上述未同步的多线程读到的值可能有:

  • (100, 200):x,y更新前;
  • (110, 200):x更新后,y更新前;
  • (110, 220):x,y更新后。

如果读取到(110, 200),即读到了更新后的x,更新前的y,那么可能会造成程序的逻辑错误,无法保证读取的多个变量状态保持一致。

有些时候,通过一些巧妙的转换,可以把非原子操作变为原子操作。例如,上述代码如果改造成:

1
2
3
4
5
6
7
class Point {
int[] ps;
public void set(int x, int y) {
int[] ps = new int[] { x, y };
this.ps = ps;
}
}

就不再需要写同步,因为this.ps = ps是引用赋值的原子操作。而语句:

1
int[] ps = new int[] { x, y };

这里的ps是方法内部定义的局部变量,每个线程都会有各自的局部变量,互不影响,并且互不可见,并不需要同步。

不过要注意,读方法在复制int[]数组的过程中仍然需要同步。

不可变对象无需同步

如果多线程读写的是一个不可变对象,那么无需同步,因为不会修改对象的状态:

1
2
3
4
5
6
7
8
9
class Data {
List<String> names;
void set(String[] names) {
this.names = List.of(names);
}
List<String> get() {
return this.names;
}
}

注意到set()方法内部创建了一个不可变List,这个List包含的对象也是不可变对象String,因此,整个List<String>对象都是不可变的,因此读写均无需同步。

分析变量是否能被多线程访问时,首先要理清概念,多线程同时执行的是方法。对于下面这个例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class Status {
List<String> names;
int x;
int y;
void set(String[] names, int n) {
List<String> ns = List.of(names);
this.names = ns;
int step = n * 10;
this.x += step;
this.y += step;
}
StatusRecord get() {
return new StatusRecord(this.names, this.x, this.y);
}
}

如果有A、B两个线程,同时执行是指:

  • 可能同时执行set();
  • 可能同时执行get();
  • 可能A执行set(),同时B执行get()。

类的成员变量namesxy显然能被多线程同时读写,但局部变量(包括方法参数)如果没有“逃逸”,那么只有当前线程可见。局部变量step仅在set()方法内部使用,因此每个线程同时执行set时都有一份独立的step存储在线程的栈上,互不影响,但是局部变量ns虽然每个线程也各有一份,但后续赋值后对其他线程就变成可见了。对set()方法同步时,如果要最小化synchronized代码块,可以改写如下:

1
2
3
4
5
6
7
8
9
10
void set(String[] names, int n) {
// 局部变量其他线程不可见:
List<String> ns = List.of(names);
int step = n * 10;
synchronized(this) {
this.names = ns;
this.x += step;
this.y += step;
}
}

因此,深入理解多线程还需理解变量在栈上的存储方式,基本类型和引用类型的存储方式也不同。

小结

多线程同时读写共享变量时,可能会造成逻辑错误,因此需要通过synchronized同步;

同步的本质就是给指定对象加锁,加锁后才能继续执行后续代码;

注意加锁对象必须是同一个实例;

对JVM定义的单个原子操作不需要同步。

我们知道Java程序依靠synchronized对线程进行同步,使用synchronized的时候,锁住的是哪个对象非常重要。

让线程自己选择锁对象往往会使得代码逻辑混乱,也不利于封装。更好的方法是把synchronized逻辑封装起来。例如,我们编写一个计数器如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
public class Counter {
private int count = 0;

public void add(int n) {
synchronized(this) {
count += n;
}
}

public void dec(int n) {
synchronized(this) {
count -= n;
}
}

public int get() {
return count;
}
}

这样一来,线程调用add()dec()方法时,它不必关心同步逻辑,因为synchronized代码块在add()dec()方法内部。并且,我们注意到,synchronized锁住的对象是this,即当前实例,这又使得创建多个Counter实例的时候,它们之间互不影响,可以并发执行:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
var c1 = Counter();
var c2 = Counter();

// 对c1进行操作的线程:
new Thread(() -> {
c1.add();
}).start();
new Thread(() -> {
c1.dec();
}).start();

// 对c2进行操作的线程:
new Thread(() -> {
c2.add();
}).start();
new Thread(() -> {
c2.dec();
}).start();

现在,对于Counter类,多线程可以正确调用。

如果一个类被设计为允许多线程正确访问,我们就说这个类就是“线程安全”的(thread-safe),上面的Counter类就是线程安全的。Java标准库的java.lang.StringBuffer也是线程安全的。

还有一些不变类,例如StringIntegerLocalDate,它们的所有成员变量都是final,多线程同时访问时只能读不能写,这些不变类也是线程安全的。

最后,类似Math这些只提供静态方法,没有成员变量的类,也是线程安全的。

除了上述几种少数情况,大部分类,例如ArrayList,都是非线程安全的类,我们不能在多线程中修改它们。但是,如果所有线程都只读取,不写入,那么ArrayList是可以安全地在线程间共享的。

我们再观察Counter的代码:

1
2
3
4
5
6
7
8
public class Counter {
public void add(int n) {
synchronized(this) {
count += n;
}
}
...
}

当我们锁住的是this实例时,实际上可以用synchronized修饰这个方法。下面两种写法是等价的:

1
2
3
4
5
public void add(int n) {
synchronized(this) { // 锁住this
count += n;
} // 解锁
}

写法二:

1
2
3
public synchronized void add(int n) { // 锁住this
count += n;
} // 解锁

因此,用synchronized修饰的方法就是同步方法,它表示整个方法都必须用this实例加锁。

我们再思考一下,如果对一个静态方法添加synchronized修饰符,它锁住的是哪个对象?

1
2
3
public synchronized static void test(int n) {
...
}

对于static方法,是没有this实例的,因为static方法是针对类而不是实例。但是我们注意到任何一个类都有一个由JVM自动创建的Class实例,因此,对static方法添加synchronized,锁住的是该类的Class实例。上述synchronized static方法实际上相当于:

1
2
3
4
5
6
7
public class Counter {
public static void test(int n) {
synchronized(Counter.class) {
...
}
}
}

我们再考察Counterget()方法:

1
2
3
4
5
6
7
8
public class Counter {
private int count;

public int get() {
return count;
}
...
}

它没有同步,因为读一个int变量不需要同步。

然而,如果我们把代码稍微改一下,返回一个包含两个int的对象:

1
2
3
4
5
6
7
8
9
10
11
12
public class Counter {
private int first;
private int last;

public Pair get() {
Pair p = new Pair();
p.first = first;
p.last = last;
return p;
}
...
}

就必须要同步了。

小结

synchronized修饰方法可以把整个方法变为同步代码块,synchronized方法加锁对象是this

通过合理的设计和数据封装可以让一个类变为“线程安全”;

一个类没有特殊说明,默认不是thread-safe;

多线程能否安全访问某个非线程安全的实例,需要具体问题具体分析。

死锁

Java的线程锁是可重入的锁。

什么是可重入的锁?我们还是来看例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
public class Counter {
private int count = 0;

public synchronized void add(int n) {
if (n < 0) {
dec(-n);
} else {
count += n;
}
}

public synchronized void dec(int n) {
count += n;
}
}

观察synchronized修饰的add()方法,一旦线程执行到add()方法内部,说明它已经获取了当前实例的this锁。如果传入的n < 0,将在add()方法内部调用dec()方法。由于dec()方法也需要获取this锁,现在问题来了:

对同一个线程,能否在获取到锁以后继续获取同一个锁?

答案是肯定的。JVM允许同一个线程重复获取同一个锁,这种能被同一个线程反复获取的锁,就叫做可重入锁

由于Java的线程锁是可重入锁,所以,获取锁的时候,不但要判断是否是第一次获取,还要记录这是第几次获取。每获取一次锁,记录+1,每退出synchronized块,记录-1,减到0的时候,才会真正释放锁。

死锁

一个线程可以获取一个锁后,再继续获取另一个锁。例如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
public void add(int m) {
synchronized(lockA) { // 获得lockA的锁
this.value += m;
synchronized(lockB) { // 获得lockB的锁
this.another += m;
} // 释放lockB的锁
} // 释放lockA的锁
}

public void dec(int m) {
synchronized(lockB) { // 获得lockB的锁
this.another -= m;
synchronized(lockA) { // 获得lockA的锁
this.value -= m;
} // 释放lockA的锁
} // 释放lockB的锁
}

在获取多个锁的时候,不同线程获取多个不同对象的锁可能导致死锁。对于上述代码,线程1和线程2如果分别执行add()dec()方法时:

  • 线程1:进入add(),获得lockA
  • 线程2:进入dec(),获得lockB

随后:

  • 线程1:准备获得lockB,失败,等待中;
  • 线程2:准备获得lockA,失败,等待中。

此时,两个线程各自持有不同的锁,然后各自试图获取对方手里的锁,造成了双方无限等待下去,这就是死锁。

死锁发生后,没有任何机制能解除死锁,只能强制结束JVM进程。

因此,在编写多线程应用时,要特别注意防止死锁。因为死锁一旦形成,就只能强制结束进程。

那么我们应该如何避免死锁呢?答案是:线程获取锁的顺序要一致。即严格按照先获取lockA,再获取lockB的顺序,改写dec()方法如下:

1
2
3
4
5
6
7
8
public void dec(int m) {
synchronized(lockA) { // 获得lockA的锁
this.value -= m;
synchronized(lockB) { // 获得lockB的锁
this.another -= m;
} // 释放lockB的锁
} // 释放lockA的锁
}

练习

请观察死锁的代码输出,然后修复。

下载练习

小结

Java的synchronized锁是可重入锁;

死锁产生的条件是多线程各自持有不同的锁,并互相试图获取对方已持有的锁,导致无限等待;

避免死锁的方法是多线程获取锁的顺序要一致。



在Java程序中,synchronized解决了多线程竞争的问题。例如,对于一个任务管理器,多个线程同时往队列中添加任务,可以用synchronized加锁:

1
2
3
4
5
6
7
class TaskQueue {
Queue<String> queue = new LinkedList<>();

public synchronized void addTask(String s) {
this.queue.add(s);
}
}

但是synchronized并没有解决多线程协调的问题。

仍然以上面的TaskQueue为例,我们再编写一个getTask()方法取出队列的第一个任务:

1
2
3
4
5
6
7
8
9
10
11
12
13
class TaskQueue {
Queue<String> queue = new LinkedList<>();

public synchronized void addTask(String s) {
this.queue.add(s);
}

public synchronized String getTask() {
while (queue.isEmpty()) {
}
return queue.remove();
}
}

上述代码看上去没有问题:getTask()内部先判断队列是否为空,如果为空,就循环等待,直到另一个线程往队列中放入了一个任务,while()循环退出,就可以返回队列的元素了。

但实际上while()循环永远不会退出。因为线程在执行while()循环时,已经在getTask()入口获取了this锁,其他线程根本无法调用addTask(),因为addTask()执行条件也是获取this锁。

因此,执行上述代码,线程会在getTask()中因为死循环而100%占用CPU资源。

如果深入思考一下,我们想要的执行效果是:

  • 线程1可以调用addTask()不断往队列中添加任务;
  • 线程2可以调用getTask()从队列中获取任务。如果队列为空,则getTask()应该等待,直到队列中至少有一个任务时再返回。

因此,多线程协调运行的原则就是:当条件不满足时,线程进入等待状态;当条件满足时,线程被唤醒,继续执行任务。

对于上述TaskQueue,我们先改造getTask()方法,在条件不满足时,线程进入等待状态:

1
2
3
4
5
6
public synchronized String getTask() {
while (queue.isEmpty()) {
this.wait();
}
return queue.remove();
}

当一个线程执行到getTask()方法内部的while循环时,它必定已经获取到了this锁,此时,线程执行while条件判断,如果条件成立(队列为空),线程将执行this.wait(),进入等待状态。

这里的关键是:wait()方法必须在当前获取的锁对象上调用,这里获取的是this锁,因此调用this.wait()

调用wait()方法后,线程进入等待状态,wait()方法不会返回,直到将来某个时刻,线程从等待状态被其他线程唤醒后,wait()方法才会返回,然后,继续执行下一条语句。

有些仔细的童鞋会指出:即使线程在getTask()内部等待,其他线程如果拿不到this锁,照样无法执行addTask(),肿么办?

这个问题的关键就在于wait()方法的执行机制非常复杂。首先,它不是一个普通的Java方法,而是定义在Object类的一个native方法,也就是由JVM的C代码实现的。其次,必须在synchronized块中才能调用wait()方法,因为wait()方法调用时,会释放线程获得的锁,wait()方法返回时,线程又会重新试图获得锁。

因此,只能在锁对象上调用wait()方法。因为在getTask()中,我们获得了this锁,因此,只能在this对象上调用wait()方法:

1
2
3
4
5
6
7
8
public synchronized String getTask() {
while (queue.isEmpty()) {
// 释放this锁:
this.wait();
// 重新获取this锁
}
return queue.remove();
}

当一个线程在this.wait()等待时,它就会释放this锁,从而使得其他线程能够在addTask()方法获得this锁。

现在我们面临第二个问题:如何让等待的线程被重新唤醒,然后从wait()方法返回?答案是在相同的锁对象上调用notify()方法。我们修改addTask()如下:

1
2
3
4
public synchronized void addTask(String s) {
this.queue.add(s);
this.notify(); // 唤醒在this锁等待的线程
}

注意到在往队列中添加了任务后,线程立刻对this锁对象调用notify()方法,这个方法会唤醒一个正在this锁等待的线程(就是在getTask()中位于this.wait()的线程),从而使得等待线程从this.wait()方法返回。

我们来看一个完整的例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import java.util.*;

public class Main {
public static void main(String[] args) throws InterruptedException {
var q = new TaskQueue();
var ts = new ArrayList<Thread>();
for (int i=0; i<5; i++) {
var t = new Thread() {
public void run() {
// 执行task:
while (true) {
try {
String s = q.getTask();
System.out.println("execute task: " + s);
} catch (InterruptedException e) {
return;
}
}
}
};
t.start();
ts.add(t);
}
var add = new Thread(() -> {
for (int i=0; i<10; i++) {
// 放入task:
String s = "t-" + Math.random();
System.out.println("add task: " + s);
q.addTask(s);
try { Thread.sleep(100); } catch(InterruptedException e) {}
}
});
add.start();
add.join();
Thread.sleep(100);
for (var t : ts) {
t.interrupt();
}
}
}

class TaskQueue {
Queue<String> queue = new LinkedList<>();

public synchronized void addTask(String s) {
this.queue.add(s);
this.notifyAll();
}

public synchronized String getTask() throws InterruptedException {
while (queue.isEmpty()) {
this.wait();
}
return queue.remove();
}
}

这个例子中,我们重点关注addTask()方法,内部调用了this.notifyAll()而不是this.notify(),使用notifyAll()将唤醒所有当前正在this锁等待的线程,而notify()只会唤醒其中一个(具体哪个依赖操作系统,有一定的随机性)。这是因为可能有多个线程正在getTask()方法内部的wait()中等待,使用notifyAll()将一次性全部唤醒。通常来说,notifyAll()更安全。有些时候,如果我们的代码逻辑考虑不周,用notify()会导致只唤醒了一个线程,而其他线程可能永远等待下去醒不过来了。

但是,注意到wait()方法返回时需要重新获得this锁。假设当前有3个线程被唤醒,唤醒后,首先要等待执行addTask()的线程结束此方法后,才能释放this锁,随后,这3个线程中只能有一个获取到this锁,剩下两个将继续等待。

再注意到我们在while()循环中调用wait(),而不是if语句:

1
2
3
4
5
6
public synchronized String getTask() throws InterruptedException {
if (queue.isEmpty()) {
this.wait();
}
return queue.remove();
}

这种写法实际上是错误的,因为线程被唤醒时,需要再次获取this锁。多个线程被唤醒后,只有一个线程能获取this锁,此刻,该线程执行queue.remove()可以获取到队列的元素,然而,剩下的线程如果获取this锁后执行queue.remove(),此刻队列可能已经没有任何元素了,所以,要始终在while循环中wait(),并且每次被唤醒后拿到this锁就必须再次判断:

1
2
3
while (queue.isEmpty()) {
this.wait();
}

所以,正确编写多线程代码是非常困难的,需要仔细考虑的条件非常多,任何一个地方考虑不周,都会导致多线程运行时不正常。

multithread

小结

waitnotify用于多线程协调运行:

  • synchronized内部可以调用wait()使线程进入等待状态;
  • 必须在已获得的锁对象上调用wait()方法;
  • synchronized内部可以调用notify()notifyAll()唤醒其他等待线程;
  • 必须在已获得的锁对象上调用notify()notifyAll()方法;
  • 已唤醒的线程还需要重新获得锁后才能继续执行。

使用ReentrantLock

从Java 5开始,引入了一个高级的处理并发的java.util.concurrent包,它提供了大量更高级的并发功能,能大大简化多线程程序的编写。

我们知道Java语言直接提供了synchronized关键字用于加锁,但这种锁一是很重,二是获取时必须一直等待,没有额外的尝试机制。

java.util.concurrent.locks包提供的ReentrantLock用于替代synchronized加锁,我们来看一下传统的synchronized代码:

1
2
3
4
5
6
7
8
9
public class Counter {
private int count;

public void add(int n) {
synchronized(this) {
count += n;
}
}
}

如果用ReentrantLock替代,可以把代码改造为:

1
2
3
4
5
6
7
8
9
10
11
12
13
public class Counter {
private final Lock lock = new ReentrantLock();
private int count;

public void add(int n) {
lock.lock();
try {
count += n;
} finally {
lock.unlock();
}
}
}

因为synchronized是Java语言层面提供的语法,所以我们不需要考虑异常,而ReentrantLock是Java代码实现的锁,我们就必须先获取锁,然后在finally中正确释放锁。

顾名思义,ReentrantLock是可重入锁,它和synchronized一样,一个线程可以多次获取同一个锁。

synchronized不同的是,ReentrantLock可以尝试获取锁:

1
2
3
4
5
6
7
if (lock.tryLock(1, TimeUnit.SECONDS)) {
try {
...
} finally {
lock.unlock();
}
}

上述代码在尝试获取锁的时候,最多等待1秒。如果1秒后仍未获取到锁,tryLock()返回false,程序就可以做一些额外处理,而不是无限等待下去。

所以,使用ReentrantLock比直接使用synchronized更安全,线程在tryLock()失败的时候不会导致死锁。

小结

ReentrantLock可以替代synchronized进行同步;

ReentrantLock获取锁更安全;

必须先获取到锁,再进入try {...}代码块,最后使用finally保证释放锁;

可以使用tryLock()尝试获取锁。



使用Condition

使用ReentrantLock比直接使用synchronized更安全,可以替代synchronized进行线程同步。

但是,synchronized可以配合waitnotify实现线程在条件不满足时等待,条件满足时唤醒,用ReentrantLock我们怎么编写waitnotify的功能呢?

答案是使用Condition对象来实现waitnotify的功能。

我们仍然以TaskQueue为例,把前面用synchronized实现的功能通过ReentrantLockCondition来实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class TaskQueue {
private final Lock lock = new ReentrantLock();
private final Condition condition = lock.newCondition();
private Queue<String> queue = new LinkedList<>();

public void addTask(String s) {
lock.lock();
try {
queue.add(s);
condition.signalAll();
} finally {
lock.unlock();
}
}

public String getTask() {
lock.lock();
try {
while (queue.isEmpty()) {
condition.await();
}
return queue.remove();
} finally {
lock.unlock();
}
}
}

可见,使用Condition时,引用的Condition对象必须从Lock实例的newCondition()返回,这样才能获得一个绑定了Lock实例的Condition实例。

Condition提供的await()signal()signalAll()原理和synchronized锁对象的wait()notify()notifyAll()是一致的,并且其行为也是一样的:

  • await()会释放当前锁,进入等待状态;
  • signal()会唤醒某个等待线程;
  • signalAll()会唤醒所有等待线程;
  • 唤醒线程从await()返回后需要重新获得锁。

此外,和tryLock()类似,await()可以在等待指定时间后,如果还没有被其他线程通过signal()signalAll()唤醒,可以自己醒来:

1
2
3
4
5
if (condition.await(1, TimeUnit.SECOND)) {
// 被其他线程唤醒
} else {
// 指定时间内没有被其他线程唤醒
}

可见,使用Condition配合Lock,我们可以实现更灵活的线程同步。

小结

Condition可以替代waitnotify

Condition对象必须从Lock对象获取。



使用ReadWriteLock

前面讲到的ReentrantLock保证了只有一个线程可以执行临界区代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
public class Counter {
private final Lock lock = new ReentrantLock();
private int[] counts = new int[10];

public void inc(int index) {
lock.lock();
try {
counts[index] += 1;
} finally {
lock.unlock();
}
}

public int[] get() {
lock.lock();
try {
return Arrays.copyOf(counts, counts.length);
} finally {
lock.unlock();
}
}
}

但是有些时候,这种保护有点过头。因为我们发现,任何时刻,只允许一个线程修改,也就是调用inc()方法是必须获取锁,但是,get()方法只读取数据,不修改数据,它实际上允许多个线程同时调用。

实际上我们想要的是:允许多个线程同时读,但只要有一个线程在写,其他线程就必须等待:

允许 不允许
不允许 不允许

使用ReadWriteLock可以解决这个问题,它保证:

  • 只允许一个线程写入(其他线程既不能写入也不能读取);
  • 没有写入时,多个线程允许同时读(提高性能)。

ReadWriteLock实现这个功能十分容易。我们需要创建一个ReadWriteLock实例,然后分别获取读锁和写锁:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
public class Counter {
private final ReadWriteLock rwlock = new ReentrantReadWriteLock();
// 注意: 一对读锁和写锁必须从同一个rwlock获取:
private final Lock rlock = rwlock.readLock();
private final Lock wlock = rwlock.writeLock();
private int[] counts = new int[10];

public void inc(int index) {
wlock.lock(); // 加写锁
try {
counts[index] += 1;
} finally {
wlock.unlock(); // 释放写锁
}
}

public int[] get() {
rlock.lock(); // 加读锁
try {
return Arrays.copyOf(counts, counts.length);
} finally {
rlock.unlock(); // 释放读锁
}
}
}

把读写操作分别用读锁和写锁来加锁,在读取时,多个线程可以同时获得读锁,这样就大大提高了并发读的执行效率。

使用ReadWriteLock时,适用条件是同一个数据,有大量线程读取,但仅有少数线程修改。

例如,一个论坛的帖子,回复可以看做写入操作,它是不频繁的,但是,浏览可以看做读取操作,是非常频繁的,这种情况就可以使用ReadWriteLock

小结

使用ReadWriteLock可以提高读取效率:

  • ReadWriteLock只允许一个线程写入;
  • ReadWriteLock允许多个线程在没有写入时同时读取;
  • ReadWriteLock适合读多写少的场景。


前面介绍的ReadWriteLock可以解决多线程同时读,但只有一个线程能写的问题。

如果我们深入分析ReadWriteLock,会发现它有个潜在的问题:如果有线程正在读,写线程需要等待读线程释放锁后才能获取写锁,即读的过程中不允许写,这是一种悲观的读锁。

要进一步提升并发执行效率,Java 8引入了新的读写锁:StampedLock

StampedLockReadWriteLock相比,改进之处在于:读的过程中也允许获取写锁后写入!这样一来,我们读的数据就可能不一致,所以,需要一点额外的代码来判断读的过程中是否有写入,这种读锁是一种乐观锁。

乐观锁的意思就是乐观地估计读的过程中大概率不会有写入,因此被称为乐观锁。反过来,悲观锁则是读的过程中拒绝有写入,也就是写入必须等待。显然乐观锁的并发效率更高,但一旦有小概率的写入导致读取的数据不一致,需要能检测出来,再读一遍就行。

我们来看例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
public class Point {
private final StampedLock stampedLock = new StampedLock();

private double x;
private double y;

public void move(double deltaX, double deltaY) {
long stamp = stampedLock.writeLock(); // 获取写锁
try {
x += deltaX;
y += deltaY;
} finally {
stampedLock.unlockWrite(stamp); // 释放写锁
}
}

public double distanceFromOrigin() {
long stamp = stampedLock.tryOptimisticRead(); // 获得一个乐观读锁
// 注意下面两行代码不是原子操作
// 假设x,y = (100,200)
double currentX = x;
// 此处已读取到x=100,但x,y可能被写线程修改为(300,400)
double currentY = y;
// 此处已读取到y,如果没有写入,读取是正确的(100,200)
// 如果有写入,读取是错误的(100,400)
if (!stampedLock.validate(stamp)) { // 检查乐观读锁后是否有其他写锁发生
stamp = stampedLock.readLock(); // 获取一个悲观读锁
try {
currentX = x;
currentY = y;
} finally {
stampedLock.unlockRead(stamp); // 释放悲观读锁
}
}
return Math.sqrt(currentX * currentX + currentY * currentY);
}
}

ReadWriteLock相比,写入的加锁是完全一样的,不同的是读取。注意到首先我们通过tryOptimisticRead()获取一个乐观读锁,并返回版本号。接着进行读取,读取完成后,我们通过validate()去验证版本号,如果在读取过程中没有写入,版本号不变,验证成功,我们就可以放心地继续后续操作。如果在读取过程中有写入,版本号会发生变化,验证将失败。在失败的时候,我们再通过获取悲观读锁再次读取。由于写入的概率不高,程序在绝大部分情况下可以通过乐观读锁获取数据,极少数情况下使用悲观读锁获取数据。

可见,StampedLock把读锁细分为乐观读和悲观读,能进一步提升并发效率。但这也是有代价的:一是代码更加复杂,二是StampedLock是不可重入锁,不能在一个线程中反复获取同一个锁。

StampedLock还提供了更复杂的将悲观读锁升级为写锁的功能,它主要使用在if-then-update的场景:即先读,如果读的数据满足条件,就返回,如果读的数据不满足条件,再尝试写。

小结

StampedLock提供了乐观读锁,可取代ReadWriteLock以进一步提升并发性能;

StampedLock是不可重入锁。

使用Semaphore

前面我们讲了各种锁的实现,本质上锁的目的是保护一种受限资源,保证同一时刻只有一个线程能访问(ReentrantLock),或者只有一个线程能写入(ReadWriteLock)。

还有一种受限资源,它需要保证同一时刻最多有N个线程能访问,比如同一时刻最多创建100个数据库连接,最多允许10个用户下载等。

这种限制数量的锁,如果用Lock数组来实现,就太麻烦了。

这种情况就可以使用Semaphore,例如,最多允许3个线程同时访问:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
public class AccessLimitControl {
// 任意时刻仅允许最多3个线程获取许可:
final Semaphore semaphore = new Semaphore(3);

public String access() throws Exception {
// 如果超过了许可数量,其他线程将在此等待:
semaphore.acquire();
try {
// TODO:
return UUID.randomUUID().toString();
} finally {
semaphore.release();
}
}
}

使用Semaphore先调用acquire()获取,然后通过try ... finally保证在finally中释放。

调用acquire()可能会进入等待,直到满足条件为止。也可以使用tryAcquire()指定等待时间:

1
2
3
4
5
6
7
8
if (semaphore.tryAcquire(3, TimeUnit.SECONDS)) {
// 指定等待时间3秒内获取到许可:
try {
// TODO:
} finally {
semaphore.release();
}
}

Semaphore本质上就是一个信号计数器,用于限制同一时间的最大访问数量。

小结

如果要对某一受限资源进行限流访问,可以使用Semaphore,保证同一时间最多N个线程访问受限资源。



使用Concurrent集合

我们在前面已经通过ReentrantLockCondition实现了一个BlockingQueue

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
public class TaskQueue {
private final Lock lock = new ReentrantLock();
private final Condition condition = lock.newCondition();
private Queue<String> queue = new LinkedList<>();

public void addTask(String s) {
lock.lock();
try {
queue.add(s);
condition.signalAll();
} finally {
lock.unlock();
}
}

public String getTask() {
lock.lock();
try {
while (queue.isEmpty()) {
condition.await();
}
return queue.remove();
} finally {
lock.unlock();
}
}
}

BlockingQueue的意思就是说,当一个线程调用这个TaskQueuegetTask()方法时,该方法内部可能会让线程变成等待状态,直到队列条件满足不为空,线程被唤醒后,getTask()方法才会返回。

因为BlockingQueue非常有用,所以我们不必自己编写,可以直接使用Java标准库的java.util.concurrent包提供的线程安全的集合:ArrayBlockingQueue

除了BlockingQueue外,针对ListMapSetDeque等,java.util.concurrent包也提供了对应的并发集合类。我们归纳一下:

interface non-thread-safe thread-safe
List ArrayList CopyOnWriteArrayList
Map HashMap ConcurrentHashMap
Set HashSet / TreeSet CopyOnWriteArraySet
Queue ArrayDeque / LinkedList ArrayBlockingQueue / LinkedBlockingQueue
Deque ArrayDeque / LinkedList LinkedBlockingDeque

使用这些并发集合与使用非线程安全的集合类完全相同。我们以ConcurrentHashMap为例:

1
2
3
4
5
Map<String, String> map = new ConcurrentHashMap<>();
// 在不同的线程读写:
map.put("A", "1");
map.put("B", "2");
map.get("A", "1");

因为所有的同步和加锁的逻辑都在集合内部实现,对外部调用者来说,只需要正常按接口引用,其他代码和原来的非线程安全代码完全一样。即当我们需要多线程访问时,把:

1
Map<String, String> map = new HashMap<>();

改为:

1
Map<String, String> map = new ConcurrentHashMap<>();

就可以了。

java.util.Collections工具类还提供了一个旧的线程安全集合转换器,可以这么用:

1
2
Map unsafeMap = new HashMap();
Map threadSafeMap = Collections.synchronizedMap(unsafeMap);

但是它实际上是用一个包装类包装了非线程安全的Map,然后对所有读写方法都用synchronized加锁,这样获得的线程安全集合的性能比java.util.concurrent集合要低很多,所以不推荐使用。

小结

使用java.util.concurrent包提供的线程安全的并发集合可以大大简化多线程编程:

多线程同时读写并发集合是安全的;

尽量使用Java标准库提供的并发集合,避免自己编写同步代码。



使用Atomic

Java的java.util.concurrent包除了提供底层锁、并发集合外,还提供了一组原子操作的封装类,它们位于java.util.concurrent.atomic包。

我们以AtomicInteger为例,它提供的主要操作有:

  • 增加值并返回新值:int addAndGet(int delta)
  • 加1后返回新值:int incrementAndGet()
  • 获取当前值:int get()
  • 用CAS方式设置:int compareAndSet(int expect, int update)

Atomic类是通过无锁(lock-free)的方式实现的线程安全(thread-safe)访问。它的主要原理是利用了CAS:Compare and Set。

如果我们自己通过CAS编写incrementAndGet(),它大概长这样:

1
2
3
4
5
6
7
8
public int incrementAndGet(AtomicInteger var) {
int prev, next;
do {
prev = var.get();
next = prev + 1;
} while ( ! var.compareAndSet(prev, next));
return next;
}

CAS是指,在这个操作中,如果AtomicInteger的当前值是prev,那么就更新为next,返回true。如果AtomicInteger的当前值不是prev,就什么也不干,返回false。通过CAS操作并配合do ... while循环,即使其他线程修改了AtomicInteger的值,最终的结果也是正确的。

我们利用AtomicLong可以编写一个多线程安全的全局唯一ID生成器:

1
2
3
4
5
6
7
class IdGenerator {
AtomicLong var = new AtomicLong(0);

public long getNextId() {
return var.incrementAndGet();
}
}

通常情况下,我们并不需要直接用do ... while循环调用compareAndSet实现复杂的并发操作,而是用incrementAndGet()这样的封装好的方法,因此,使用起来非常简单。

在高度竞争的情况下,还可以使用Java 8提供的LongAdderLongAccumulator

小结

使用java.util.concurrent.atomic提供的原子操作可以简化多线程编程:

  • 原子操作实现了无锁的线程安全;
  • 适用于计数器,累加器等。


留言與分享

JAVA-多线程1

分類 编程语言, Java

现代操作系统(Windows,macOS,Linux)都可以执行多任务。多任务就是同时运行多个任务,例如:

multitask

CPU执行代码都是一条一条顺序执行的,但是,即使是单核cpu,也可以同时运行多个任务。因为操作系统执行多任务实际上就是让CPU对多个任务轮流交替执行。

例如,假设我们有语文、数学、英语3门作业要做,每个作业需要30分钟。我们把这3门作业看成是3个任务,可以做1分钟语文作业,再做1分钟数学作业,再做1分钟英语作业:

hurry

这样轮流做下去,在某些人眼里看来,做作业的速度就非常快,看上去就像同时在做3门作业一样

ooops

类似的,操作系统轮流让多个任务交替执行,例如,让浏览器执行0.001秒,让QQ执行0.001秒,再让音乐播放器执行0.001秒,在人看来,CPU就是在同时执行多个任务。

即使是多核CPU,因为通常任务的数量远远多于CPU的核数,所以任务也是交替执行的。

进程

在计算机中,我们把一个任务称为一个进程,浏览器就是一个进程,视频播放器是另一个进程,类似的,音乐播放器和Word都是进程。

某些进程内部还需要同时执行多个子任务。例如,我们在使用Word时,Word可以让我们一边打字,一边进行拼写检查,同时还可以在后台进行打印,我们把子任务称为线程。

进程和线程的关系就是:一个进程可以包含一个或多个线程,但至少会有一个线程。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
                        ┌──────────┐
│Process │
│┌────────┐│
┌──────────┐││ Thread ││┌──────────┐
│Process ││└────────┘││Process │
│┌────────┐││┌────────┐││┌────────┐│
┌──────────┐││ Thread ││││ Thread ││││ Thread ││
│Process ││└────────┘││└────────┘││└────────┘│
│┌────────┐││┌────────┐││┌────────┐││┌────────┐│
││ Thread ││││ Thread ││││ Thread ││││ Thread ││
│└────────┘││└────────┘││└────────┘││└────────┘│
└──────────┘└──────────┘└──────────┘└──────────┘
┌──────────────────────────────────────────────┐
│ Operating System │
└──────────────────────────────────────────────┘

操作系统调度的最小任务单位其实不是进程,而是线程。常用的Windows、Linux等操作系统都采用抢占式多任务,如何调度线程完全由操作系统决定,程序自己不能决定什么时候执行,以及执行多长时间。

因为同一个应用程序,既可以有多个进程,也可以有多个线程,因此,实现多任务的方法,有以下几种:

多进程模式(每个进程只有一个线程):

1
2
3
4
5
6
┌──────────┐ ┌──────────┐ ┌──────────┐
│Process │ │Process │ │Process │
│┌────────┐│ │┌────────┐│ │┌────────┐│
││ Thread ││ ││ Thread ││ ││ Thread ││
│└────────┘│ │└────────┘│ │└────────┘│
└──────────┘ └──────────┘ └──────────┘

多线程模式(一个进程有多个线程):

1
2
3
4
5
6
7
8
9
┌────────────────────┐
│Process │
│┌────────┐┌────────┐│
││ Thread ││ Thread ││
│└────────┘└────────┘│
│┌────────┐┌────────┐│
││ Thread ││ Thread ││
│└────────┘└────────┘│
└────────────────────┘

多进程+多线程模式(复杂度最高):

1
2
3
4
5
6
7
8
9
┌──────────┐┌──────────┐┌──────────┐
│Process ││Process ││Process │
│┌────────┐││┌────────┐││┌────────┐│
││ Thread ││││ Thread ││││ Thread ││
│└────────┘││└────────┘││└────────┘│
│┌────────┐││┌────────┐││┌────────┐│
││ Thread ││││ Thread ││││ Thread ││
│└────────┘││└────────┘││└────────┘│
└──────────┘└──────────┘└──────────┘

进程 vs 线程

进程和线程是包含关系,但是多任务既可以由多进程实现,也可以由单进程内的多线程实现,还可以混合多进程+多线程。

具体采用哪种方式,要考虑到进程和线程的特点。

和多线程相比,多进程的缺点在于:

  • 创建进程比创建线程开销大,尤其是在Windows系统上;
  • 进程间通信比线程间通信要慢,因为线程间通信就是读写同一个变量,速度很快。

而多进程的优点在于:

多进程稳定性比多线程高,因为在多进程的情况下,一个进程崩溃不会影响其他进程,而在多线程的情况下,任何一个线程崩溃会直接导致整个进程崩溃。

多线程

Java语言内置了多线程支持:一个Java程序实际上是一个JVM进程,JVM进程用一个主线程来执行main()方法,在main()方法内部,我们又可以启动多个线程。此外,JVM还有负责垃圾回收的其他工作线程等。

因此,对于大多数Java程序来说,我们说多任务,实际上是说如何使用多线程实现多任务。

和单线程相比,多线程编程的特点在于:多线程经常需要读写共享数据,并且需要同步。例如,播放电影时,就必须由一个线程播放视频,另一个线程播放音频,两个线程需要协调运行,否则画面和声音就不同步。因此,多线程编程的复杂度高,调试更困难。

Java多线程编程的特点又在于:

  • 多线程模型是Java程序最基本的并发模型;
  • 后续读写网络、数据库、Web开发等都依赖Java多线程模型。

因此,必须掌握Java多线程编程才能继续深入学习其他内容。

Java语言内置了多线程支持。当Java程序启动的时候,实际上是启动了一个JVM进程,然后,JVM启动主线程来执行main()方法。在main()方法中,我们又可以启动其他线程。

要创建一个新线程非常容易,我们需要实例化一个Thread实例,然后调用它的start()方法:

1
2
3
4
5
6
7
// 多线程
public class Main {
public static void main(String[] args) {
Thread t = new Thread();
t.start(); // 启动新线程
}
}

但是这个线程启动后实际上什么也不做就立刻结束了。我们希望新线程能执行指定的代码,有以下几种方法:

方法一:从Thread派生一个自定义类,然后覆写run()方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// 多线程
public class Main {
public static void main(String[] args) {
Thread t = new MyThread();
t.start(); // 启动新线程
}
}

class MyThread extends Thread {
@Override
public void run() {
System.out.println("start new thread!");
}
}

执行上述代码,注意到start()方法会在内部自动调用实例的run()方法。

方法二:创建Thread实例时,传入一个Runnable实例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// 多线程
public class Main {
public static void main(String[] args) {
Thread t = new Thread(new MyRunnable());
t.start(); // 启动新线程
}
}

class MyRunnable implements Runnable {
@Override
public void run() {
System.out.println("start new thread!");
}
}

或者用Java 8引入的lambda语法进一步简写为:

1
2
3
4
5
6
7
8
9
// 多线程
public class Main {
public static void main(String[] args) {
Thread t = new Thread(() -> {
System.out.println("start new thread!");
});
t.start(); // 启动新线程
}
}

有童鞋会问,使用线程执行的打印语句,和直接在main()方法执行有区别吗?

区别大了去了。我们看以下代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
public class Main {
public static void main(String[] args) {
System.out.println("main start...");
Thread t = new Thread() {
public void run() {
System.out.println("thread run...");
System.out.println("thread end.");
}
};
t.start();
System.out.println("main end...");
}
}

我们用蓝色表示主线程,也就是main线程,main线程执行的代码有4行,首先打印main start,然后创建Thread对象,紧接着调用start()启动新线程。当start()方法被调用时,JVM就创建了一个新线程,我们通过实例变量t来表示这个新线程对象,并开始执行。

接着,main线程继续执行打印main end语句,而t线程在main线程执行的同时会并发执行,打印thread runthread end语句。

run()方法结束时,新线程就结束了。而main()方法结束时,主线程也结束了。

我们再来看线程的执行顺序:

  1. main线程肯定是先打印main start,再打印main end
  2. t线程肯定是先打印thread run,再打印thread end

但是,除了可以肯定,main start会先打印外,main end打印在thread run之前、thread end之后或者之间,都无法确定。因为从t线程开始运行以后,两个线程就开始同时运行了,并且由操作系统调度,程序本身无法确定线程的调度顺序。

要模拟并发执行的效果,我们可以在线程中调用Thread.sleep(),强迫当前线程暂停一段时间:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
// 多线程
public class Main {
public static void main(String[] args) {
System.out.println("main start...");
Thread t = new Thread() {
public void run() {
System.out.println("thread run...");
try {
Thread.sleep(10);
} catch (InterruptedException e) {}
System.out.println("thread end.");
}
};
t.start();
try {
Thread.sleep(20);
} catch (InterruptedException e) {}
System.out.println("main end...");
}
}

sleep()传入的参数是毫秒。调整暂停时间的大小,我们可以看到main线程和t线程执行的先后顺序。

要特别注意:直接调用Thread实例的run()方法是无效的:

1
2
3
4
5
6
7
8
9
10
11
12
public class Main {
public static void main(String[] args) {
Thread t = new MyThread();
t.run();
}
}

class MyThread extends Thread {
public void run() {
System.out.println("hello");
}
}

直接调用run()方法,相当于调用了一个普通的Java方法,当前线程并没有任何改变,也不会启动新线程。上述代码实际上是在main()方法内部又调用了run()方法,打印hello语句是在main线程中执行的,没有任何新线程被创建。

必须调用Thread实例的start()方法才能启动新线程,如果我们查看Thread类的源代码,会看到start()方法内部调用了一个private native void start0()方法,native修饰符表示这个方法是由JVM虚拟机内部的C代码实现的,不是由Java代码实现的。

线程的优先级

可以对线程设定优先级,设定优先级的方法是:

1
Thread.setPriority(int n) // 1~10, 默认值5

JVM自动把1(低)~10(高)的优先级映射到操作系统实际优先级上(不同操作系统有不同的优先级数量)。优先级高的线程被操作系统调度的优先级较高,操作系统对高优先级线程可能调度更频繁,但我们决不能通过设置优先级来确保高优先级的线程一定会先执行。

练习

创建一个新线程。

下载练习

小结

Java用Thread对象表示一个线程,通过调用start()启动一个新线程;

一个线程对象只能调用一次start()方法;

线程的执行代码写在run()方法中;

线程调度由操作系统决定,程序本身无法决定调度顺序;

Thread.sleep()可以把当前线程暂停一段时间。

线程的状态

在Java程序中,一个线程对象只能调用一次start()方法启动新线程,并在新线程中执行run()方法。一旦run()方法执行完毕,线程就结束了。因此,Java线程的状态有以下几种:

  • New:新创建的线程,尚未执行;
  • Runnable:运行中的线程,正在执行run()方法的Java代码;
  • Blocked:运行中的线程,因为某些操作被阻塞而挂起;
  • Waiting:运行中的线程,因为某些操作在等待中;
  • Timed Waiting:运行中的线程,因为执行sleep()方法正在计时等待;
  • Terminated:线程已终止,因为run()方法执行完毕。

用一个状态转移图表示如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
         ┌─────────────┐
│ New │
└─────────────┘


┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐
┌─────────────┐ ┌─────────────┐
││ Runnable │ │ Blocked ││
└─────────────┘ └─────────────┘
│┌─────────────┐ ┌─────────────┐│
│ Waiting │ │Timed Waiting│
│└─────────────┘ └─────────────┘│
─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─


┌─────────────┐
│ Terminated │
└─────────────┘

当线程启动后,它可以在RunnableBlockedWaitingTimed Waiting这几个状态之间切换,直到最后变成Terminated状态,线程终止。

线程终止的原因有:

  • 线程正常终止:run()方法执行到return语句返回;
  • 线程意外终止:run()方法因为未捕获的异常导致线程终止;
  • 对某个线程的Thread实例调用stop()方法强制终止(强烈不推荐使用)。

一个线程还可以等待另一个线程直到其运行结束。例如,main线程在启动t线程后,可以通过t.join()等待t线程结束后再继续运行:

1
2
3
4
5
6
7
8
9
10
11
12
// 多线程
public class Main {
public static void main(String[] args) throws InterruptedException {
Thread t = new Thread(() -> {
System.out.println("hello");
});
System.out.println("start");
t.start(); // 启动t线程
t.join(); // 此处main线程会等待t结束
System.out.println("end");
}
}

main线程对线程对象t调用join()方法时,主线程将等待变量t表示的线程运行结束,即join就是指等待该线程结束,然后才继续往下执行自身线程。所以,上述代码打印顺序可以肯定是main线程先打印startt线程再打印hellomain线程最后再打印end

如果t线程已经结束,对实例t调用join()会立刻返回。此外,join(long)的重载方法也可以指定一个等待时间,超过等待时间后就不再继续等待。

小结

Java线程对象Thread的状态包括:NewRunnableBlockedWaitingTimed WaitingTerminated

通过对另一个线程对象调用join()方法可以等待其执行结束;

可以指定等待时间,超过等待时间线程仍然没有结束就不再等待;

对已经运行结束的线程调用join()方法会立刻返回。



如果线程需要执行一个长时间任务,就可能需要能中断线程。中断线程就是其他线程给该线程发一个信号,该线程收到信号后结束执行run()方法,使得自身线程能立刻结束运行。

我们举个栗子:假设从网络下载一个100M的文件,如果网速很慢,用户等得不耐烦,就可能在下载过程中点“取消”,这时,程序就需要中断下载线程的执行。

中断一个线程非常简单,只需要在其他线程中对目标线程调用interrupt()方法,目标线程需要反复检测自身状态是否是interrupted状态,如果是,就立刻结束运行。

我们还是看示例代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
// 中断线程
public class Main {
public static void main(String[] args) throws InterruptedException {
Thread t = new MyThread();
t.start();
Thread.sleep(1); // 暂停1毫秒
t.interrupt(); // 中断t线程
t.join(); // 等待t线程结束
System.out.println("end");
}
}

class MyThread extends Thread {
public void run() {
int n = 0;
while (! isInterrupted()) {
n ++;
System.out.println(n + " hello!");
}
}
}

仔细看上述代码,main线程通过调用t.interrupt()方法中断t线程,但是要注意,interrupt()方法仅仅向t线程发出了“中断请求”,至于t线程是否能立刻响应,要看具体代码。而t线程的while循环会检测isInterrupted(),所以上述代码能正确响应interrupt()请求,使得自身立刻结束运行run()方法。

如果线程处于等待状态,例如,t.join()会让main线程进入等待状态,此时,如果对main线程调用interrupt()join()方法会立刻抛出InterruptedException,因此,目标线程只要捕获到join()方法抛出的InterruptedException,就说明有其他线程对其调用了interrupt()方法,通常情况下该线程应该立刻结束运行。

我们来看下面的示例代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
// 中断线程
public class Main {
public static void main(String[] args) throws InterruptedException {
Thread t = new MyThread();
t.start();
Thread.sleep(1000);
t.interrupt(); // 中断t线程
t.join(); // 等待t线程结束
System.out.println("end");
}
}

class MyThread extends Thread {
public void run() {
Thread hello = new HelloThread();
hello.start(); // 启动hello线程
try {
hello.join(); // 等待hello线程结束
} catch (InterruptedException e) {
System.out.println("interrupted!");
}
hello.interrupt();
}
}

class HelloThread extends Thread {
public void run() {
int n = 0;
while (!isInterrupted()) {
n++;
System.out.println(n + " hello!");
try {
Thread.sleep(100);
} catch (InterruptedException e) {
break;
}
}
}
}

main线程通过调用t.interrupt()从而通知t线程中断,而此时t线程正位于hello.join()的等待中,此方法会立刻结束等待并抛出InterruptedException。由于我们在t线程中捕获了InterruptedException,因此,就可以准备结束该线程。在t线程结束前,对hello线程也进行了interrupt()调用通知其中断。如果去掉这一行代码,可以发现hello线程仍然会继续运行,且JVM不会退出。

另一个常用的中断线程的方法是设置标志位。我们通常会用一个running标志位来标识线程是否应该继续运行,在外部线程中,通过把HelloThread.running置为false,就可以让线程结束:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
// 中断线程
public class Main {
public static void main(String[] args) throws InterruptedException {
HelloThread t = new HelloThread();
t.start();
Thread.sleep(1);
t.running = false; // 标志位置为false
}
}

class HelloThread extends Thread {
public volatile boolean running = true;
public void run() {
int n = 0;
while (running) {
n ++;
System.out.println(n + " hello!");
}
System.out.println("end!");
}
}

注意到HelloThread的标志位boolean running是一个线程间共享的变量。线程间共享变量需要使用volatile关键字标记,确保每个线程都能读取到更新后的变量值。

为什么要对线程间共享的变量用关键字volatile声明?这涉及到Java的内存模型。在Java虚拟机中,变量的值保存在主内存中,但是,当线程访问变量时,它会先获取一个副本,并保存在自己的工作内存中。如果线程修改了变量的值,虚拟机会在某个时刻把修改后的值回写到主内存,但是,这个时间是不确定的!

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐
Main Memory
│ │
┌───────┐┌───────┐┌───────┐
│ │ var A ││ var B ││ var C │ │
└───────┘└───────┘└───────┘
│ │ ▲ │ ▲ │
─ ─ ─│─│─ ─ ─ ─ ─ ─ ─ ─│─│─ ─ ─
│ │ │ │
┌ ─ ─ ┼ ┼ ─ ─ ┐ ┌ ─ ─ ┼ ┼ ─ ─ ┐
▼ │ ▼ │
│ ┌───────┐ │ │ ┌───────┐ │
│ var A │ │ var C │
│ └───────┘ │ │ └───────┘ │
Thread 1 Thread 2
└ ─ ─ ─ ─ ─ ─ ┘ └ ─ ─ ─ ─ ─ ─ ┘

这会导致如果一个线程更新了某个变量,另一个线程读取的值可能还是更新前的。例如,主内存的变量a = true,线程1执行a = false时,它在此刻仅仅是把变量a的副本变成了false,主内存的变量a还是true,在JVM把修改后的a回写到主内存之前,其他线程读取到的a的值仍然是true,这就造成了多线程之间共享的变量不一致。

因此,volatile关键字的目的是告诉虚拟机:

  • 每次访问变量时,总是获取主内存的最新值;
  • 每次修改变量后,立刻回写到主内存。

volatile关键字解决的是可见性问题:当一个线程修改了某个共享变量的值,其他线程能够立刻看到修改后的值。

如果我们去掉volatile关键字,运行上述程序,发现效果和带volatile差不多,这是因为在x86的架构下,JVM回写主内存的速度非常快,但是,换成ARM的架构,就会有显著的延迟。

小结

对目标线程调用interrupt()方法可以请求中断一个线程,目标线程通过检测isInterrupted()标志获取自身是否已中断。如果目标线程处于等待状态,该线程会捕获到InterruptedException

目标线程检测到isInterrupted()true或者捕获了InterruptedException都应该立刻结束自身线程;

通过标志位判断需要正确使用volatile关键字;

volatile关键字解决了共享变量在线程间的可见性问题。

守护线程

Java程序入口就是由JVM启动main线程,main线程又可以启动其他线程。当所有线程都运行结束时,JVM退出,进程结束。

如果有一个线程没有退出,JVM进程就不会退出。所以,必须保证所有线程都能及时结束。

但是有一种线程的目的就是无限循环,例如,一个定时触发任务的线程:

1
2
3
4
5
6
7
8
9
10
11
12
13
class TimerThread extends Thread {
@Override
public void run() {
while (true) {
System.out.println(LocalTime.now());
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
break;
}
}
}
}

如果这个线程不结束,JVM进程就无法结束。问题是,由谁负责结束这个线程?

然而这类线程经常没有负责人来负责结束它们。但是,当其他线程结束时,JVM进程又必须要结束,怎么办?

答案是使用守护线程(Daemon Thread)。

守护线程是指为其他线程服务的线程。在JVM中,所有非守护线程都执行完毕后,无论有没有守护线程,虚拟机都会自动退出。

因此,JVM退出时,不必关心守护线程是否已结束。

如何创建守护线程呢?方法和普通线程一样,只是在调用start()方法前,调用setDaemon(true)把该线程标记为守护线程:

1
2
3
Thread t = new MyThread();
t.setDaemon(true);
t.start();

在守护线程中,编写代码要注意:守护线程不能持有任何需要关闭的资源,例如打开文件等,因为虚拟机退出时,守护线程没有任何机会来关闭文件,这会导致数据丢失。

练习

使用守护线程。

下载练习

小结

守护线程是为其他线程服务的线程;

所有非守护线程都执行完毕后,虚拟机退出,守护线程随之结束;

守护线程不能持有需要关闭的资源(如打开文件等)。



留言與分享

JAVA-加密与安全

分類 编程语言, Java

在计算机系统中,什么是加密与安全呢?

我们举个栗子:假设Bob要给Alice发一封邮件,在邮件传送的过程中,黑客可能会窃取到邮件的内容,所以需要防窃听。黑客还可能会篡改邮件的内容,Alice必须有能力识别出邮件有没有被篡改。最后,黑客可能假冒Bob给Alice发邮件,Alice必须有能力识别出伪造的邮件。

所以,应对潜在的安全威胁,需要做到三防:

  • 防窃听
  • 防篡改
  • 防伪造

java-security

计算机加密技术就是为了实现上述目标,而现代计算机密码学理论是建立在严格的数学理论基础上的,密码学已经逐渐发展成一门科学。对于绝大多数开发者来说,设计一个安全的加密算法非常困难,验证一个加密算法是否安全更加困难,当前被认为安全的加密算法仅仅是迄今为止尚未被攻破。因此,要编写安全的计算机程序,我们要做到:

  • 不要自己设计山寨的加密算法;
  • 不要自己实现已有的加密算法;
  • 不要自己修改已有的加密算法。

本章我们会介绍最常用的加密算法,以及如何通过Java代码实现。

要学习编码算法,我们先来看一看什么是编码。

ASCII码就是一种编码,字母A的编码是十六进制的0x41,字母B0x42,以此类推:

字母 ASCII编码
A 0x41
B 0x42
C 0x43
D 0x44

因为ASCII编码最多只能有128个字符,要想对更多的文字进行编码,就需要用Unicode。而中文的中使用Unicode编码就是0x4e2d,使用UTF-8则需要3个字节编码:

汉字 Unicode编码 UTF-8编码
0x4e2d 0xe4b8ad
0x6587 0xe69687
0x7f16 0xe7bc96
0x7801 0xe7a081

因此,最简单的编码是直接给每个字符指定一个若干字节表示的整数,复杂一点的编码就需要根据一个已有的编码推算出来。

比如UTF-8编码,它是一种不定长编码,但可以从给定字符的Unicode编码推算出来。

URL编码

URL编码是浏览器发送数据给服务器时使用的编码,它通常附加在URL的参数部分,例如:

https://www.baidu.com/s?wd=%E4%B8%AD%E6%96%87

之所以需要URL编码,是因为出于兼容性考虑,很多服务器只识别ASCII字符。但如果URL中包含中文、日文这些非ASCII字符怎么办?不要紧,URL编码有一套规则:

  • 如果字符是A~Za~z0~9以及-_.*,则保持不变;
  • 如果是其他字符,先转换为UTF-8编码,然后对每个字节以%XX表示。

例如:字符的UTF-8编码是0xe4b8ad,因此,它的URL编码是%E4%B8%AD。URL编码总是大写。

Java标准库提供了一个URLEncoder类来对任意字符串进行URL编码:

1
2
3
4
5
6
7
8
9
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;

public class Main {
public static void main(String[] args) {
String encoded = URLEncoder.encode("中文!", StandardCharsets.UTF_8);
System.out.println(encoded);
}
}

上述代码的运行结果是%E4%B8%AD%E6%96%87%21的URL编码是%E4%B8%AD的URL编码是%E6%96%87!虽然是ASCII字符,也要对其编码为%21

和标准的URL编码稍有不同,URLEncoder把空格字符编码成+,而现在的URL编码标准要求空格被编码为%20,不过,服务器都可以处理这两种情况。

如果服务器收到URL编码的字符串,就可以对其进行解码,还原成原始字符串。Java标准库的URLDecoder就可以解码:

1
2
3
4
5
6
7
8
9
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;

public class Main {
public static void main(String[] args) {
String decoded = URLDecoder.decode("%E4%B8%AD%E6%96%87%21", StandardCharsets.UTF_8);
System.out.println(decoded);
}
}

要特别注意:URL编码是编码算法,不是加密算法。URL编码的目的是把任意文本数据编码为%前缀表示的文本,编码后的文本仅包含A~Za~z0~9-_.*%,便于浏览器和服务器处理。

Base64编码

URL编码是对字符进行编码,表示成%xx的形式,而Base64编码是对二进制数据进行编码,表示成文本格式。

Base64编码可以把任意长度的二进制数据变为纯文本,且只包含A~Za~z0~9+/=这些字符。它的原理是把3字节的二进制数据按6bit一组,用4个int整数表示,然后查表,把int整数用索引对应到字符,得到编码后的字符串。

举个例子:3个byte数据分别是e4b8ad,按6bit分组得到390b222d

1
2
3
4
5
6
7
8
9
┌───────────────┬───────────────┬───────────────┐
│ e4 │ b8 │ ad │
└───────────────┴───────────────┴───────────────┘
┌─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┐
│1│1│1│0│0│1│0│0│1│0│1│1│1│0│0│0│1│0│1│0│1│1│0│1│
└─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┘
┌───────────┬───────────┬───────────┬───────────┐
│ 39 │ 0b │ 22 │ 2d │
└───────────┴───────────┴───────────┴───────────┘

因为6位整数的范围总是0~63,所以,能用64个字符表示:字符A~Z对应索引0~25,字符a~z对应索引26~51,字符0~9对应索引52~61,最后两个索引6263分别用字符+/表示。

在Java中,二进制数据就是byte[]数组。Java标准库提供了Base64来对byte[]数组进行编解码:

1
2
3
4
5
6
7
8
9
import java.util.*;

public class Main {
public static void main(String[] args) {
byte[] input = new byte[] { (byte) 0xe4, (byte) 0xb8, (byte) 0xad };
String b64encoded = Base64.getEncoder().encodeToString(input);
System.out.println(b64encoded);
}
}

编码后得到5Lit4个字符。要对Base64解码,仍然用Base64这个类:

1
2
3
4
5
6
7
8
import java.util.*;

public class Main {
public static void main(String[] args) {
byte[] output = Base64.getDecoder().decode("5Lit");
System.out.println(Arrays.toString(output)); // [-28, -72, -83]
}
}

有的童鞋会问:如果输入的byte[]数组长度不是3的整数倍肿么办?这种情况下,需要对输入的末尾补一个或两个0x00,编码后,在结尾加一个=表示补充了1个0x00,加两个=表示补充了2个0x00,解码的时候,去掉末尾补充的一个或两个0x00即可。

实际上,因为编码后的长度加上=总是4的倍数,所以即使不加=也可以计算出原始输入的byte[]。Base64编码的时候可以用withoutPadding()去掉=,解码出来的结果是一样的:

1
2
3
4
5
6
7
8
9
10
11
12
13
import java.util.*;

public class Main {
public static void main(String[] args) {
byte[] input = new byte[] { (byte) 0xe4, (byte) 0xb8, (byte) 0xad, 0x21 };
String b64encoded = Base64.getEncoder().encodeToString(input);
String b64encoded2 = Base64.getEncoder().withoutPadding().encodeToString(input);
System.out.println(b64encoded);
System.out.println(b64encoded2);
byte[] output = Base64.getDecoder().decode(b64encoded2);
System.out.println(Arrays.toString(output));
}
}

因为标准的Base64编码会出现+/=,所以不适合把Base64编码后的字符串放到URL中。一种针对URL的Base64编码可以在URL中使用的Base64编码,它仅仅是把+变成-/变成_

1
2
3
4
5
6
7
8
9
10
11
import java.util.*;

public class Main {
public static void main(String[] args) {
byte[] input = new byte[] { 0x01, 0x02, 0x7f, 0x00 };
String b64encoded = Base64.getUrlEncoder().encodeToString(input);
System.out.println(b64encoded);
byte[] output = Base64.getUrlDecoder().decode(b64encoded);
System.out.println(Arrays.toString(output));
}
}

Base64编码的目的是把二进制数据变成文本格式,这样在很多文本中就可以处理二进制数据。例如,电子邮件协议就是文本协议,如果要在电子邮件中添加一个二进制文件,就可以用Base64编码,然后以文本的形式传送。

Base64编码的缺点是传输效率会降低,因为它把原始数据的长度增加了1/3。

和URL编码一样,Base64编码是一种编码算法,不是加密算法。

如果把Base64的64个字符编码表换成32个、48个或者58个,就可以使用Base32编码,Base48编码和Base58编码。字符越少,编码的效率就会越低。

小结

URL编码和Base64编码都是编码算法,它们不是加密算法;

URL编码的目的是把任意文本数据编码为%前缀表示的文本,便于浏览器和服务器处理;

Base64编码的目的是把任意二进制数据编码为文本,但编码后数据量会增加1/3。

哈希算法(Hash)又称摘要算法(Digest),它的作用是:对任意一组输入数据进行计算,得到一个固定长度的输出摘要。

哈希算法最重要的特点就是:

  • 相同的输入一定得到相同的输出;
  • 不同的输入大概率得到不同的输出。

哈希算法的目的就是为了验证原始数据是否被篡改。

Java字符串的hashCode()就是一个哈希算法,它的输入是任意字符串,输出是固定的4字节int整数:

1
2
3
"hello".hashCode(); // 0x5e918d2
"hello, java".hashCode(); // 0x7a9d88e8
"hello, bob".hashCode(); // 0xa0dbae2f

两个相同的字符串永远会计算出相同的hashCode,否则基于hashCode定位的HashMap就无法正常工作。这也是为什么当我们自定义一个class时,覆写equals()方法时我们必须正确覆写hashCode()方法。

哈希碰撞

哈希碰撞是指,两个不同的输入得到了相同的输出:

1
2
"AaAaAa".hashCode(); // 0x7460e8c0
"BBAaBB".hashCode(); // 0x7460e8c0

有童鞋会问:碰撞能不能避免?答案是不能。碰撞是一定会出现的,因为输出的字节长度是固定的,StringhashCode()输出是4字节整数,最多只有4294967296种输出,但输入的数据长度是不固定的,有无数种输入。所以,哈希算法是把一个无限的输入集合映射到一个有限的输出集合,必然会产生碰撞。

碰撞不可怕,我们担心的不是碰撞,而是碰撞的概率,因为碰撞概率的高低关系到哈希算法的安全性。一个安全的哈希算法必须满足:

  • 碰撞概率低;
  • 不能猜测输出。

不能猜测输出是指,输入的任意一个bit的变化会造成输出完全不同,这样就很难从输出反推输入(只能依靠暴力穷举)。假设一种哈希算法有如下规律:

1
2
3
hashA("java001") = "123456"
hashA("java002") = "123457"
hashA("java003") = "123458"

那么很容易从输出123459反推输入,这种哈希算法就不安全。安全的哈希算法从输出是看不出任何规律的:

1
2
3
hashB("java001") = "123456"
hashB("java002") = "580271"
hashB("java003") = ???

常用的哈希算法有:

算法 输出长度(位) 输出长度(字节)
MD5 128 bits 16 bytes
SHA-1 160 bits 20 bytes
RipeMD-160 160 bits 20 bytes
SHA-256 256 bits 32 bytes
SHA-512 512 bits 64 bytes

根据碰撞概率,哈希算法的输出长度越长,就越难产生碰撞,也就越安全。

Java标准库提供了常用的哈希算法,并且有一套统一的接口。我们以MD5算法为例,看看如何对输入计算哈希:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import java.security.MessageDigest;
import java.util.HexFormat;

public class Main {
public static void main(String[] args) throws Exception {
// 创建一个MessageDigest实例:
MessageDigest md = MessageDigest.getInstance("MD5");
// 反复调用update输入数据:
md.update("Hello".getBytes("UTF-8"));
md.update("World".getBytes("UTF-8"));
byte[] result = md.digest(); // 16 bytes: 68e109f0f40ca72a15e05cc22786f8e6
System.out.println(HexFormat.of().formatHex(result));
}
}

使用MessageDigest时,我们首先根据哈希算法获取一个MessageDigest实例,然后,反复调用update(byte[])输入数据。当输入结束后,调用digest()方法获得byte[]数组表示的摘要,最后,把它转换为十六进制的字符串。

运行上述代码,可以得到输入HelloWorld的MD5是68e109f0f40ca72a15e05cc22786f8e6

哈希算法的用途

因为相同的输入永远会得到相同的输出,因此,如果输入被修改了,得到的输出就会不同。

我们在网站上下载软件的时候,经常看到下载页显示的哈希:

file-md5

如何判断下载到本地的软件是原始的、未经篡改的文件?我们只需要自己计算一下本地文件的哈希值,再与官网公开的哈希值对比,如果相同,说明文件下载正确,否则,说明文件已被篡改。

哈希算法的另一个重要用途是存储用户口令。如果直接将用户的原始口令存放到数据库中,会产生极大的安全风险:

  • 数据库管理员能够看到用户明文口令;
  • 数据库数据一旦泄漏,黑客即可获取用户明文口令。

不存储用户的原始口令,那么如何对用户进行认证?

方法是存储用户口令的哈希,例如,MD5。

在用户输入原始口令后,系统计算用户输入的原始口令的MD5并与数据库存储的MD5对比,如果一致,说明口令正确,否则,口令错误。

因此,数据库存储用户名和口令的表内容应该像下面这样:

username password
bob f30aa7a662c728b7407c54ae6bfd27d1
alice 25d55ad283aa400af464c76d713c07ad
tim bed128365216c019988915ed3add75fb

这样一来,数据库管理员看不到用户的原始口令。即使数据库泄漏,黑客也无法拿到用户的原始口令。想要拿到用户的原始口令,必须用暴力穷举的方法,一个口令一个口令地试,直到某个口令计算的MD5恰好等于指定值。

使用哈希口令时,还要注意防止彩虹表攻击。

什么是彩虹表?难道是这个:

彩虹表

上面讲到了,如果只拿到MD5,从MD5反推明文口令,只能使用暴力穷举的方法。

然而黑客并不笨,暴力穷举会消耗大量的算力和时间。但是,如果有一个预先计算好的常用口令和它们的MD5的对照表:

常用口令 MD5
hello123 f30aa7a662c728b7407c54ae6bfd27d1
12345678 25d55ad283aa400af464c76d713c07ad
passw0rd bed128365216c019988915ed3add75fb
19700101 570da6d5277a646f6552b8832012f5dc
20201231 6879c0ae9117b50074ce0a0d4c843060

这个表就是彩虹表。如果用户使用了常用口令,黑客从MD5一下就能反查到原始口令:

bob的MD5:f30aa7a662c728b7407c54ae6bfd27d1,原始口令:hello123

alice的MD5:25d55ad283aa400af464c76d713c07ad,原始口令:12345678

tim的MD5:bed128365216c019988915ed3add75fb,原始口令:passw0rd

这就是为什么不要使用常用密码,以及不要使用生日作为密码的原因。

即使用户使用了常用口令,我们也可以采取措施来抵御彩虹表攻击,方法是对每个口令额外添加随机数,这个方法称之为加盐(salt):

1
digest = md5(salt+inputPassword)

经过加盐处理的数据库表,内容如下:

username salt password
bob H1r0a a5022319ff4c56955e22a74abcc2c210
alice 7$p2w e5de688c99e961ed6e560b972dab8b6a
tim z5Sk9 1eee304b92dc0d105904e7ab58fd2f64

加盐的目的在于使黑客的彩虹表失效,即使用户使用常用口令,也无法从MD5反推原始口令。

SHA-1

SHA-1也是一种哈希算法,它的输出是160 bits,即20字节。SHA-1是由美国国家安全局开发的,SHA算法实际上是一个系列,包括SHA-0(已废弃)、SHA-1、SHA-256、SHA-512等。

在Java中使用SHA-1,和MD5完全一样,只需要把算法名称改为"SHA-1"

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import java.security.MessageDigest;
import java.util.HexFormat;

public class Main {
public static void main(String[] args) throws Exception {
// 创建一个MessageDigest实例:
MessageDigest md = MessageDigest.getInstance("SHA-1");
// 反复调用update输入数据:
md.update("Hello".getBytes("UTF-8"));
md.update("World".getBytes("UTF-8"));
byte[] result = md.digest(); // 20 bytes: db8ac1c259eb89d4a131b253bacfca5f319d54f2
System.out.println(HexFormat.of().formatHex(result));
}
}

类似的,计算SHA-256,我们需要传入名称"SHA-256",计算SHA-512,我们需要传入名称"SHA-512"。Java标准库支持的所有哈希算法可以在这里查到。

注意

MD5因为输出长度较短,短时间内破解是可能的,目前已经不推荐使用。

小结

哈希算法可用于验证数据完整性,具有防篡改检测的功能;

常用的哈希算法有MD5、SHA-1等;

用哈希存储口令时要考虑彩虹表攻击。

BouncyCastle

我们知道,Java标准库提供了一系列常用的哈希算法。

但如果我们要用的某种算法,Java标准库没有提供怎么办?

方法一:自己写一个,难度很大;

方法二:找一个现成的第三方库,直接使用。

BouncyCastle就是一个提供了很多哈希算法和加密算法的第三方库。它提供了Java标准库没有的一些算法,例如,RipeMD160哈希算法。

我们来看一下如何使用BouncyCastle这个第三方提供的算法。

首先,我们必须把BouncyCastle提供的jar包放到classpath中。这个jar包就是bcprov-jdk18on-xxx.jar,可以从官方网站下载。

Java标准库的java.security包提供了一种标准机制,允许第三方提供商无缝接入。我们要使用BouncyCastle提供的RipeMD160算法,需要先把BouncyCastle注册一下:

1
2
3
4
5
6
7
8
9
10
11
public class Main {
public static void main(String[] args) throws Exception {
// 注册BouncyCastle:
Security.addProvider(new BouncyCastleProvider());
// 按名称正常调用:
MessageDigest md = MessageDigest.getInstance("RipeMD160");
md.update("HelloWorld".getBytes("UTF-8"));
byte[] result = md.digest();
System.out.println(HexFormat.of().formatHex(result));
}
}

其中,注册BouncyCastle是通过下面的语句实现的:

1
Security.addProvider(new BouncyCastleProvider());

注册只需要在启动时进行一次,后续就可以使用BouncyCastle提供的所有哈希算法和加密算法。

练习

使用BouncyCastle提供的RipeMD160计算哈希。

下载练习

小结

BouncyCastle是一个开源的第三方算法提供商;

BouncyCastle提供了很多Java标准库没有提供的哈希算法和加密算法;

使用第三方算法前需要通过Security.addProvider()注册。



Hmac算法

在前面讲到哈希算法时,我们说,存储用户的哈希口令时,要加盐存储,目的就在于抵御彩虹表攻击。

我们回顾一下哈希算法:

1
digest = hash(input)

正是因为相同的输入会产生相同的输出,我们加盐的目的就在于,使得输入有所变化:

1
digest = hash(salt + input)

这个salt可以看作是一个额外的“认证码”,同样的输入,不同的认证码,会产生不同的输出。因此,要验证输出的哈希,必须同时提供“认证码”。

Hmac算法就是一种基于密钥的消息认证码算法,它的全称是Hash-based Message Authentication Code,是一种更安全的消息摘要算法。

Hmac算法总是和某种哈希算法配合起来用的。例如,我们使用MD5算法,对应的就是HmacMD5算法,它相当于“加盐”的MD5:

1
HmacMD5 ≈ md5(secure_random_key, input)

因此,HmacMD5可以看作带有一个安全的key的MD5。使用HmacMD5而不是用MD5加salt,有如下好处:

  • HmacMD5使用的key长度是64字节,更安全;
  • Hmac是标准算法,同样适用于SHA-1等其他哈希算法;
  • Hmac输出和原有的哈希算法长度一致。

可见,Hmac本质上就是把key混入摘要的算法。验证此哈希时,除了原始的输入数据,还要提供key。

为了保证安全,我们不会自己指定key,而是通过Java标准库的KeyGenerator生成一个安全的随机的key。下面是使用HmacMD5的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import javax.crypto.*;
import java.util.HexFormat;

public class Main {
public static void main(String[] args) throws Exception {
KeyGenerator keyGen = KeyGenerator.getInstance("HmacMD5");
SecretKey key = keyGen.generateKey();
// 打印随机生成的key:
byte[] skey = key.getEncoded();
System.out.println(HexFormat.of().formatHex(skey));
Mac mac = Mac.getInstance("HmacMD5");
mac.init(key);
mac.update("HelloWorld".getBytes("UTF-8"));
byte[] result = mac.doFinal();
System.out.println(HexFormat.of().formatHex(result));
}
}

和MD5相比,使用HmacMD5的步骤是:

  1. 通过名称HmacMD5获取KeyGenerator实例;
  2. 通过KeyGenerator创建一个SecretKey实例;
  3. 通过名称HmacMD5获取Mac实例;
  4. SecretKey初始化Mac实例;
  5. Mac实例反复调用update(byte[])输入数据;
  6. 调用Mac实例的doFinal()获取最终的哈希值。

我们可以用Hmac算法取代原有的自定义的加盐算法,因此,存储用户名和口令的数据库结构如下:

username secret_key (64 bytes) password
bob a8c06e05f92e…5e16 7e0387872a57c85ef6dddbaa12f376de
alice e6a343693985…f4be c1f929ac2552642b302e739bc0cdbaac
tim f27a973dfdc0…6003 af57651c3a8a73303515804d4af43790

有了Hmac计算的哈希和SecretKey,我们想要验证怎么办?这时,SecretKey不能从KeyGenerator生成,而是从一个byte[]数组恢复:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import javax.crypto.*;
import javax.crypto.spec.*;
import java.util.HexFormat;

public class Main {
public static void main(String[] args) throws Exception {
byte[] hkey = HexFormat.of().parseHex(
"b648ee779d658c420420d86291ec70f5" +
"cf97521c740330972697a8fad0b55f5c" +
"5a7924e4afa99d8c5883e07d7c3f9ed0" +
"76aa544d25ed2f5ceea59dcc122babc8");
SecretKey key = new SecretKeySpec(hkey, "HmacMD5");
Mac mac = Mac.getInstance("HmacMD5");
mac.init(key);
mac.update("HelloWorld".getBytes("UTF-8"));
byte[] result = mac.doFinal();
System.out.println(HexFormat.of().formatHex(result)); // 4af40be7864efaae1473a4c601b650ae
}
}

恢复SecretKey的语句就是new SecretKeySpec(hkey, "HmacMD5")

小结

Hmac算法是一种标准的基于密钥的哈希算法,可以配合MD5、SHA-1等哈希算法,计算的摘要长度和原摘要算法长度相同。



对称加密算法就是传统的用一个密码进行加密和解密。例如,我们常用的WinZIP和WinRAR对压缩包的加密和解密,就是使用对称加密算法:

winrar

从程序的角度看,所谓加密,就是这样一个函数,它接收密码和明文,然后输出密文:

1
secret = encrypt(key, message);

而解密则相反,它接收密码和密文,然后输出明文:

1
plain = decrypt(key, secret);

在软件开发中,常用的对称加密算法有:

算法 密钥长度 工作模式 填充模式
DES 56/64 ECB/CBC/PCBC/CTR/… NoPadding/PKCS5Padding/…
AES 128/192/256 ECB/CBC/PCBC/CTR/… NoPadding/PKCS5Padding/PKCS7Padding/…
IDEA 128 ECB PKCS5Padding/PKCS7Padding/…

密钥长度直接决定加密强度,而工作模式和填充模式可以看成是对称加密算法的参数和格式选择。Java标准库提供的算法实现并不包括所有的工作模式和所有填充模式,但是通常我们只需要挑选常用的使用就可以了。

最后注意,DES算法由于密钥过短,可以在短时间内被暴力破解,所以现在已经不安全了。

使用AES加密

AES算法是目前应用最广泛的加密算法。我们先用ECB模式加密并解密:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import java.security.*;
import java.util.Base64;

import javax.crypto.*;
import javax.crypto.spec.*;

public class Main {
public static void main(String[] args) throws Exception {
// 原文:
String message = "Hello, world!";
System.out.println("Message: " + message);
// 128位密钥 = 16 bytes Key:
byte[] key = "1234567890abcdef".getBytes("UTF-8");
// 加密:
byte[] data = message.getBytes("UTF-8");
byte[] encrypted = encrypt(key, data);
System.out.println("Encrypted: " + Base64.getEncoder().encodeToString(encrypted));
// 解密:
byte[] decrypted = decrypt(key, encrypted);
System.out.println("Decrypted: " + new String(decrypted, "UTF-8"));
}

// 加密:
public static byte[] encrypt(byte[] key, byte[] input) throws GeneralSecurityException {
Cipher cipher = Cipher.getInstance("AES/ECB/PKCS5Padding");
SecretKey keySpec = new SecretKeySpec(key, "AES");
cipher.init(Cipher.ENCRYPT_MODE, keySpec);
return cipher.doFinal(input);
}

// 解密:
public static byte[] decrypt(byte[] key, byte[] input) throws GeneralSecurityException {
Cipher cipher = Cipher.getInstance("AES/ECB/PKCS5Padding");
SecretKey keySpec = new SecretKeySpec(key, "AES");
cipher.init(Cipher.DECRYPT_MODE, keySpec);
return cipher.doFinal(input);
}
}

Java标准库提供的对称加密接口非常简单,使用时按以下步骤编写代码:

  1. 根据算法名称/工作模式/填充模式获取Cipher实例;
  2. 根据算法名称初始化一个SecretKey实例,密钥必须是指定长度;
  3. 使用SecretKey初始化Cipher实例,并设置加密或解密模式;
  4. 传入明文或密文,获得密文或明文。

ECB模式是最简单的AES加密模式,它只需要一个固定长度的密钥,固定的明文会生成固定的密文,这种一对一的加密方式会导致安全性降低,更好的方式是通过CBC模式,它需要一个随机数作为IV参数,这样对于同一份明文,每次生成的密文都不同:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import java.security.*;
import java.util.Base64;
import javax.crypto.*;
import javax.crypto.spec.*;

public class Main {
public static void main(String[] args) throws Exception {
// 原文:
String message = "Hello, world!";
System.out.println("Message: " + message);
// 256位密钥 = 32 bytes Key:
byte[] key = "1234567890abcdef1234567890abcdef".getBytes("UTF-8");
// 加密:
byte[] data = message.getBytes("UTF-8");
byte[] encrypted = encrypt(key, data);
System.out.println("Encrypted: " + Base64.getEncoder().encodeToString(encrypted));
// 解密:
byte[] decrypted = decrypt(key, encrypted);
System.out.println("Decrypted: " + new String(decrypted, "UTF-8"));
}

// 加密:
public static byte[] encrypt(byte[] key, byte[] input) throws GeneralSecurityException {
Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
SecretKeySpec keySpec = new SecretKeySpec(key, "AES");
// CBC模式需要生成一个16 bytes的initialization vector:
SecureRandom sr = SecureRandom.getInstanceStrong();
byte[] iv = sr.generateSeed(16);
IvParameterSpec ivps = new IvParameterSpec(iv);
cipher.init(Cipher.ENCRYPT_MODE, keySpec, ivps);
byte[] data = cipher.doFinal(input);
// IV不需要保密,把IV和密文一起返回:
return join(iv, data);
}

// 解密:
public static byte[] decrypt(byte[] key, byte[] input) throws GeneralSecurityException {
// 把input分割成IV和密文:
byte[] iv = new byte[16];
byte[] data = new byte[input.length - 16];
System.arraycopy(input, 0, iv, 0, 16);
System.arraycopy(input, 16, data, 0, data.length);
// 解密:
Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
SecretKeySpec keySpec = new SecretKeySpec(key, "AES");
IvParameterSpec ivps = new IvParameterSpec(iv);
cipher.init(Cipher.DECRYPT_MODE, keySpec, ivps);
return cipher.doFinal(data);
}

public static byte[] join(byte[] bs1, byte[] bs2) {
byte[] r = new byte[bs1.length + bs2.length];
System.arraycopy(bs1, 0, r, 0, bs1.length);
System.arraycopy(bs2, 0, r, bs1.length, bs2.length);
return r;
}
}

在CBC模式下,需要一个随机生成的16字节IV参数,必须使用SecureRandom生成。因为多了一个IvParameterSpec实例,因此,初始化方法需要调用Cipher的一个重载方法并传入IvParameterSpec

观察输出,可以发现每次生成的IV不同,密文也不同。

小结

对称加密算法使用同一个密钥进行加密和解密,常用算法有DES、AES和IDEA等;

密钥长度由算法设计决定,AES的密钥长度是128/192/256位;

使用对称加密算法需要指定算法名称、工作模式和填充模式。

上一节我们讲的AES加密,细心的童鞋可能会发现,密钥长度是固定的128/192/256位,而不是我们用WinZip/WinRAR那样,随便输入几位都可以。

这是因为对称加密算法决定了口令必须是固定长度,然后对明文进行分块加密。又因为安全需求,口令长度往往都是128位以上,即至少16个字符。

但是我们平时使用的加密软件,输入6位、8位都可以,难道加密方式不一样?

实际上用户输入的口令并不能直接作为AES的密钥进行加密(除非长度恰好是128/192/256位),并且用户输入的口令一般都有规律,安全性远远不如安全随机数产生的随机口令。因此,用户输入的口令,通常还需要使用PBE算法,采用随机数杂凑计算出真正的密钥,再进行加密。

PBE就是Password Based Encryption的缩写,它的作用如下:

1
key = generate(userPassword, secureRandomPassword);

PBE的作用就是把用户输入的口令和一个安全随机的口令采用杂凑后计算出真正的密钥。以AES密钥为例,我们让用户输入一个口令,然后生成一个随机数,通过PBE算法计算出真正的AES口令,再进行加密,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
public class Main {
public static void main(String[] args) throws Exception {
// 把BouncyCastle作为Provider添加到java.security:
Security.addProvider(new BouncyCastleProvider());
// 原文:
String message = "Hello, world!";
// 加密口令:
String password = "hello12345";
// 16 bytes随机Salt:
byte[] salt = SecureRandom.getInstanceStrong().generateSeed(16);
System.out.println(HexFormat.of().formatHex(salt));
// 加密:
byte[] data = message.getBytes("UTF-8");
byte[] encrypted = encrypt(password, salt, data);
System.out.println("encrypted: " + Base64.getEncoder().encodeToString(encrypted));
// 解密:
byte[] decrypted = decrypt(password, salt, encrypted);
System.out.println("decrypted: " + new String(decrypted, "UTF-8"));
}

// 加密:
public static byte[] encrypt(String password, byte[] salt, byte[] input) throws GeneralSecurityException {
PBEKeySpec keySpec = new PBEKeySpec(password.toCharArray());
SecretKeyFactory skeyFactory = SecretKeyFactory.getInstance("PBEwithSHA1and128bitAES-CBC-BC");
SecretKey skey = skeyFactory.generateSecret(keySpec);
PBEParameterSpec pbeps = new PBEParameterSpec(salt, 1000);
Cipher cipher = Cipher.getInstance("PBEwithSHA1and128bitAES-CBC-BC");
cipher.init(Cipher.ENCRYPT_MODE, skey, pbeps);
return cipher.doFinal(input);
}

// 解密:
public static byte[] decrypt(String password, byte[] salt, byte[] input) throws GeneralSecurityException {
PBEKeySpec keySpec = new PBEKeySpec(password.toCharArray());
SecretKeyFactory skeyFactory = SecretKeyFactory.getInstance("PBEwithSHA1and128bitAES-CBC-BC");
SecretKey skey = skeyFactory.generateSecret(keySpec);
PBEParameterSpec pbeps = new PBEParameterSpec(salt, 1000);
Cipher cipher = Cipher.getInstance("PBEwithSHA1and128bitAES-CBC-BC");
cipher.init(Cipher.DECRYPT_MODE, skey, pbeps);
return cipher.doFinal(input);
}
}

使用PBE时,我们还需要引入BouncyCastle,并指定算法是PBEwithSHA1and128bitAES-CBC-BC。观察代码,实际上真正的AES密钥是调用Cipherinit()方法时同时传入SecretKeyPBEParameterSpec实现的。在创建PBEParameterSpec的时候,我们还指定了循环次数1000,循环次数越多,暴力破解需要的计算量就越大。

如果我们把salt和循环次数固定,就得到了一个通用的“口令”加密软件。如果我们把随机生成的salt存储在U盘,就得到了一个“口令”加USB Key的加密软件,它的好处在于,即使用户使用了一个非常弱的口令,没有USB Key仍然无法解密,因为USB Key存储的随机数密钥安全性非常高。

小结

PBE算法通过用户口令和安全的随机salt计算出Key,然后再进行加密;

Key通过口令和安全的随机salt计算得出,大大提高了安全性;

PBE算法内部使用的仍然是标准对称加密算法(例如AES)。

密钥交换算法

对称加密算法解决了数据加密的问题。我们以AES加密为例,在现实世界中,小明要向路人甲发送一个加密文件,他可以先生成一个AES密钥,对文件进行加密,然后把加密文件发送给对方。因为对方要解密,就必须需要小明生成的密钥。

现在问题来了:如何传递密钥?

在不安全的信道上传递加密文件是没有问题的,因为黑客拿到加密文件没有用。但是,如何如何在不安全的信道上安全地传输密钥?

要解决这个问题,密钥交换算法即DH算法:Diffie-Hellman算法应运而生。

DH算法解决了密钥在双方不直接传递密钥的情况下完成密钥交换,这个神奇的交换原理完全由数学理论支持。

我们来看DH算法交换密钥的步骤。假设甲乙双方需要传递密钥,他们之间可以这么做:

  1. 甲首先选择一个素数p,例如97,底数gp的一个原根,例如5,随机数a,例如123,然后计算A=g^a mod p,结果是34,然后,甲发送p=97g=5A=34给乙;
  2. 乙方收到后,也选择一个随机数b,例如,456,然后计算B = g^b mod p,结果是75,乙再同时计算s = A^b mod p,结果是22;
  3. 乙把计算的B=75发给甲,甲计算s = B^a mod p,计算结果与乙算出的结果一样,都是22。

所以最终双方协商出的密钥s是22。注意到这个密钥s并没有在网络上传输。而通过网络传输的pgAB是无法推算出s的,因为实际算法选择的素数是非常大的。

所以,更确切地说,DH算法是一个密钥协商算法,双方最终协商出一个共同的密钥,而这个密钥不会通过网络传输。

如果我们把a看成甲的私钥,A看成甲的公钥,b看成乙的私钥,B看成乙的公钥,DH算法的本质就是双方各自生成自己的私钥和公钥,私钥仅对自己可见,然后交换公钥,并根据自己的私钥和对方的公钥,生成最终的密钥secretKey,DH算法通过数学定律保证了双方各自计算出的secretKey是相同的。

使用Java实现DH算法的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import java.security.*;
import java.security.spec.*;
import javax.crypto.KeyAgreement;
import java.util.HexFormat;

public class Main {
public static void main(String[] args) {
// Bob和Alice:
Person bob = new Person("Bob");
Person alice = new Person("Alice");

// 各自生成KeyPair:
bob.generateKeyPair();
alice.generateKeyPair();

// 双方交换各自的PublicKey:
// Bob根据Alice的PublicKey生成自己的本地密钥:
bob.generateSecretKey(alice.publicKey.getEncoded());
// Alice根据Bob的PublicKey生成自己的本地密钥:
alice.generateSecretKey(bob.publicKey.getEncoded());

// 检查双方的本地密钥是否相同:
bob.printKeys();
alice.printKeys();
// 双方的SecretKey相同,后续通信将使用SecretKey作为密钥进行AES加解密...
}
}

class Person {
public final String name;

public PublicKey publicKey;
private PrivateKey privateKey;
private byte[] secretKey;

public Person(String name) {
this.name = name;
}

// 生成本地KeyPair:
public void generateKeyPair() {
try {
KeyPairGenerator kpGen = KeyPairGenerator.getInstance("DH");
kpGen.initialize(512);
KeyPair kp = kpGen.generateKeyPair();
this.privateKey = kp.getPrivate();
this.publicKey = kp.getPublic();
} catch (GeneralSecurityException e) {
throw new RuntimeException(e);
}
}

public void generateSecretKey(byte[] receivedPubKeyBytes) {
try {
// 从byte[]恢复PublicKey:
X509EncodedKeySpec keySpec = new X509EncodedKeySpec(receivedPubKeyBytes);
KeyFactory kf = KeyFactory.getInstance("DH");
PublicKey receivedPublicKey = kf.generatePublic(keySpec);
// 生成本地密钥:
KeyAgreement keyAgreement = KeyAgreement.getInstance("DH");
keyAgreement.init(this.privateKey); // 自己的PrivateKey
keyAgreement.doPhase(receivedPublicKey, true); // 对方的PublicKey
// 生成SecretKey密钥:
this.secretKey = keyAgreement.generateSecret();
} catch (GeneralSecurityException e) {
throw new RuntimeException(e);
}
}

public void printKeys() {
System.out.println("Name: " + this.name);
System.out.println("Private key: " + HexFormat.of().formatHex(this.privateKey.getEncoded()));
System.out.println("Public key: " + HexFormat.of().formatHex(this.publicKey.getEncoded()));
System.out.println("Secret key: " + HexFormat.of().formatHex(this.secretKey));
}
}

但是DH算法并未解决中间人攻击,即甲乙双方并不能确保与自己通信的是否真的是对方。消除中间人攻击需要其他方法。

练习

使用DH算法生成密钥。

下载练习

小结

DH算法是一种密钥交换协议,通信双方通过不安全的信道协商密钥,然后进行对称加密传输。

DH算法没有解决中间人攻击。



从DH算法我们可以看到,公钥-私钥组成的密钥对是非常有用的加密方式,因为公钥是可以公开的,而私钥是完全保密的,由此奠定了非对称加密的基础。

非对称加密就是加密和解密使用的不是相同的密钥:只有同一个公钥-私钥对才能正常加解密。

因此,如果小明要加密一个文件发送给小红,他应该首先向小红索取她的公钥,然后,他用小红的公钥加密,把加密文件发送给小红,此文件只能由小红的私钥解开,因为小红的私钥在她自己手里,所以,除了小红,没有任何人能解开此文件。

非对称加密的典型算法就是RSA算法,它是由Ron Rivest,Adi Shamir,Leonard Adleman这三个哥们一起发明的,所以用他们仨的姓的首字母缩写表示。

非对称加密相比对称加密的显著优点在于,对称加密需要协商密钥,而非对称加密可以安全地公开各自的公钥,在N个人之间通信的时候:使用非对称加密只需要N个密钥对,每个人只管理自己的密钥对。而使用对称加密需要则需要N*(N-1)/2个密钥,因此每个人需要管理N-1个密钥,密钥管理难度大,而且非常容易泄漏。

既然非对称加密这么好,那我们抛弃对称加密,完全使用非对称加密行不行?也不行。因为非对称加密的缺点就是运算速度非常慢,比对称加密要慢很多。

所以,在实际应用的时候,非对称加密总是和对称加密一起使用。假设小明需要给小红需要传输加密文件,他俩首先交换了各自的公钥,然后:

  1. 小明生成一个随机的AES口令,然后用小红的公钥通过RSA加密这个口令,并发给小红;
  2. 小红用自己的RSA私钥解密得到AES口令;
  3. 双方使用这个共享的AES口令用AES加密通信。

可见非对称加密实际上应用在第一步,即加密“AES口令”。这也是我们在浏览器中常用的HTTPS协议的做法,即浏览器和服务器先通过RSA交换AES口令,接下来双方通信实际上采用的是速度较快的AES对称加密,而不是缓慢的RSA非对称加密。

Java标准库提供了RSA算法的实现,示例代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import java.security.*;
import javax.crypto.Cipher;
import java.util.HexFormat;

public class Main {
public static void main(String[] args) throws Exception {
// 明文:
byte[] plain = "Hello, encrypt use RSA".getBytes("UTF-8");
// 创建公钥/私钥对:
Person alice = new Person("Alice");
// 用Alice的公钥加密:
byte[] pk = alice.getPublicKey();
System.out.println("public key: " + HexFormat.of().formatHex(pk));
byte[] encrypted = alice.encrypt(plain);
System.out.println("encrypted: " + HexFormat.of().formatHex(encrypted));
// 用Alice的私钥解密:
byte[] sk = alice.getPrivateKey();
System.out.println("private key: " + HexFormat.of().formatHex(sk));
byte[] decrypted = alice.decrypt(encrypted);
System.out.println(new String(decrypted, "UTF-8"));
}
}

class Person {
String name;
// 私钥:
PrivateKey sk;
// 公钥:
PublicKey pk;

public Person(String name) throws GeneralSecurityException {
this.name = name;
// 生成公钥/私钥对:
KeyPairGenerator kpGen = KeyPairGenerator.getInstance("RSA");
kpGen.initialize(1024);
KeyPair kp = kpGen.generateKeyPair();
this.sk = kp.getPrivate();
this.pk = kp.getPublic();
}

// 把私钥导出为字节
public byte[] getPrivateKey() {
return this.sk.getEncoded();
}

// 把公钥导出为字节
public byte[] getPublicKey() {
return this.pk.getEncoded();
}

// 用公钥加密:
public byte[] encrypt(byte[] message) throws GeneralSecurityException {
Cipher cipher = Cipher.getInstance("RSA");
cipher.init(Cipher.ENCRYPT_MODE, this.pk);
return cipher.doFinal(message);
}

// 用私钥解密:
public byte[] decrypt(byte[] input) throws GeneralSecurityException {
Cipher cipher = Cipher.getInstance("RSA");
cipher.init(Cipher.DECRYPT_MODE, this.sk);
return cipher.doFinal(input);
}
}

RSA的公钥和私钥都可以通过getEncoded()方法获得以byte[]表示的二进制数据,并根据需要保存到文件中。要从byte[]数组恢复公钥或私钥,可以这么写:

1
2
3
4
5
6
7
8
9
byte[] pkData = ...
byte[] skData = ...
KeyFactory kf = KeyFactory.getInstance("RSA");
// 恢复公钥:
X509EncodedKeySpec pkSpec = new X509EncodedKeySpec(pkData);
PublicKey pk = kf.generatePublic(pkSpec);
// 恢复私钥:
PKCS8EncodedKeySpec skSpec = new PKCS8EncodedKeySpec(skData);
PrivateKey sk = kf.generatePrivate(skSpec);

以RSA算法为例,它的密钥有256/512/1024/2048/4096等不同的长度。长度越长,密码强度越大,当然计算速度也越慢。

如果修改待加密的byte[]数据的大小,可以发现,使用512bit的RSA加密时,明文长度不能超过53字节,使用1024bit的RSA加密时,明文长度不能超过117字节,这也是为什么使用RSA的时候,总是配合AES一起使用,即用AES加密任意长度的明文,用RSA加密AES口令。

此外,只使用非对称加密算法不能防止中间人攻击。

练习

使用RSA算法加密。

下载练习

小结

非对称加密就是加密和解密使用的不是相同的密钥,只有同一个公钥-私钥对才能正常加解密;

只使用非对称加密算法不能防止中间人攻击。

签名算法

我们使用非对称加密算法的时候,对于一个公钥-私钥对,通常是用公钥加密,私钥解密。

如果使用私钥加密,公钥解密是否可行呢?实际上是完全可行的。

不过我们再仔细想一想,私钥是保密的,而公钥是公开的,用私钥加密,那相当于所有人都可以用公钥解密。这个加密有什么意义?

这个加密的意义在于,如果小明用自己的私钥加密了一条消息,比如小明喜欢小红,然后他公开了加密消息,由于任何人都可以用小明的公钥解密,从而使得任何人都可以确认小明喜欢小红这条消息肯定是小明发出的,其他人不能伪造这个消息,小明也不能抵赖这条消息不是自己写的。

因此,私钥加密得到的密文实际上就是数字签名,要验证这个签名是否正确,只能用私钥持有者的公钥进行解密验证。使用数字签名的目的是为了确认某个信息确实是由某个发送方发送的,任何人都不可能伪造消息,并且,发送方也不能抵赖。

在实际应用的时候,签名实际上并不是针对原始消息,而是针对原始消息的哈希进行签名,即:

1
signature = encrypt(privateKey, sha256(message))

对签名进行验证实际上就是用公钥解密:

1
hash = decrypt(publicKey, signature)

然后把解密后的哈希与原始消息的哈希进行对比。

因为用户总是使用自己的私钥进行签名,所以,私钥就相当于用户身份。而公钥用来给外部验证用户身份。

常用数字签名算法有:

  • MD5withRSA
  • SHA1withRSA
  • SHA256withRSA

它们实际上就是指定某种哈希算法进行RSA签名的方式。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import java.nio.charset.StandardCharsets;
import java.security.*;
import java.util.HexFormat;

public class Main {
public static void main(String[] args) throws GeneralSecurityException {
// 生成RSA公钥/私钥:
KeyPairGenerator kpGen = KeyPairGenerator.getInstance("RSA");
kpGen.initialize(1024);
KeyPair kp = kpGen.generateKeyPair();
PrivateKey sk = kp.getPrivate();
PublicKey pk = kp.getPublic();

// 待签名的消息:
byte[] message = "Hello, I am Bob!".getBytes(StandardCharsets.UTF_8);

// 用私钥签名:
Signature s = Signature.getInstance("SHA1withRSA");
s.initSign(sk);
s.update(message);
byte[] signed = s.sign();
System.out.println("signature: " + HexFormat.of().formatHex(signed));

// 用公钥验证:
Signature v = Signature.getInstance("SHA1withRSA");
v.initVerify(pk);
v.update(message);
boolean valid = v.verify(signed);
System.out.println("valid? " + valid);
}
}

使用其他公钥,或者验证签名的时候修改原始信息,都无法验证成功。

DSA签名

除了RSA可以签名外,还可以使用DSA算法进行签名。DSA是Digital Signature Algorithm的缩写,它使用ElGamal数字签名算法。

DSA只能配合SHA使用,常用的算法有:

  • SHA1withDSA
  • SHA256withDSA
  • SHA512withDSA

和RSA数字签名相比,DSA的优点是更快。

ECDSA签名

椭圆曲线签名算法ECDSA:Elliptic Curve Digital Signature Algorithm也是一种常用的签名算法,它的特点是可以从私钥推出公钥。比特币的签名算法就采用了ECDSA算法,使用标准椭圆曲线secp256k1。BouncyCastle提供了ECDSA的完整实现。

练习

使用RSA算法进行签名。

下载练习

小结

数字签名就是用发送方的私钥对原始数据进行签名,只有用发送方公钥才能通过签名验证。

数字签名用于:

  • 防止伪造;
  • 防止抵赖;
  • 检测篡改。

常用的数字签名算法包括:MD5withRSA/SHA1withRSA/SHA256withRSA/SHA1withDSA/SHA256withDSA/SHA512withDSA/ECDSA等。



我们知道,摘要算法用来确保数据没有被篡改,非对称加密算法可以对数据进行加解密,签名算法可以确保数据完整性和抗否认性,把这些算法集合到一起,并搞一套完善的标准,这就是数字证书。

因此,数字证书就是集合了多种密码学算法,用于实现数据加解密、身份认证、签名等多种功能的一种安全标准。

数字证书可以防止中间人攻击,因为它采用链式签名认证,即通过根证书(Root CA)去签名下一级证书,这样层层签名,直到最终的用户证书。而Root CA证书内置于操作系统中,所以,任何经过CA认证的数字证书都可以对其本身进行校验,确保证书本身不是伪造的。

我们在上网时常用的HTTPS协议就是数字证书的应用。浏览器会自动验证证书的有效性:

certificate

要使用数字证书,首先需要创建证书。正常情况下,一个合法的数字证书需要经过CA签名,这需要认证域名并支付一定的费用。开发的时候,我们可以使用自签名的证书,这种证书可以正常开发调试,但不能对外作为服务使用,因为其他客户端并不认可未经CA签名的证书。

注:腾讯云可申请有效期1年的免费SSL证书,Let’s Encrypt可申请有效期90天的免费SSL证书。

在Java程序中,数字证书存储在一种Java专用的key store文件中,JDK提供了一系列命令来创建和管理key store。我们用下面的命令创建一个key store,并设定口令123456:

1
keytool -storepass 123456 -genkeypair -keyalg RSA -keysize 1024 -sigalg SHA1withRSA -validity 3650 -alias mycert -keystore my.keystore -dname "CN=www.sample.com, OU=sample, O=sample, L=BJ, ST=BJ, C=CN"

几个主要的参数是:

  • keyalg:指定RSA加密算法;
  • sigalg:指定SHA1withRSA签名算法;
  • validity:指定证书有效期3650天;
  • alias:指定证书在程序中引用的名称;
  • dname:最重要的CN=www.sample.com指定了Common Name,如果证书用在HTTPS中,这个名称必须与域名完全一致。

执行上述命令,JDK会在当前目录创建一个my.keystore文件,并存储创建成功的一个私钥和一个证书,它的别名是mycert

有了key store存储的证书,我们就可以通过数字证书进行加解密和签名:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import java.io.InputStream;
import java.security.*;
import java.security.cert.*;
import javax.crypto.Cipher;
import java.util.HexFormat;

public class Main {
public static void main(String[] args) throws Exception {
byte[] message = "Hello, use X.509 cert!".getBytes("UTF-8");
// 读取KeyStore:
KeyStore ks = loadKeyStore("/my.keystore", "123456");
// 读取私钥:
PrivateKey privateKey = (PrivateKey) ks.getKey("mycert", "123456".toCharArray());
// 读取证书:
X509Certificate certificate = (X509Certificate) ks.getCertificate("mycert");
// 加密:
byte[] encrypted = encrypt(certificate, message);
System.out.println("encrypted: " + HexFormat.of().formatHex(encrypted));
// 解密:
byte[] decrypted = decrypt(privateKey, encrypted);
System.out.println("decrypted: " + new String(decrypted, "UTF-8"));
// 签名:
byte[] sign = sign(privateKey, certificate, message);
System.out.println("signature: " + HexFormat.of().formatHex(sign));
// 验证签名:
boolean verified = verify(certificate, message, sign);
System.out.println("verify: " + verified);
}

static KeyStore loadKeyStore(String keyStoreFile, String password) {
try (InputStream input = Main.class.getResourceAsStream(keyStoreFile)) {
if (input == null) {
throw new RuntimeException("file not found in classpath: " + keyStoreFile);
}
KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
ks.load(input, password.toCharArray());
return ks;
} catch (Exception e) {
throw new RuntimeException(e);
}
}

static byte[] encrypt(X509Certificate certificate, byte[] message) throws GeneralSecurityException {
Cipher cipher = Cipher.getInstance(certificate.getPublicKey().getAlgorithm());
cipher.init(Cipher.ENCRYPT_MODE, certificate.getPublicKey());
return cipher.doFinal(message);
}

static byte[] decrypt(PrivateKey privateKey, byte[] data) throws GeneralSecurityException {
Cipher cipher = Cipher.getInstance(privateKey.getAlgorithm());
cipher.init(Cipher.DECRYPT_MODE, privateKey);
return cipher.doFinal(data);
}

static byte[] sign(PrivateKey privateKey, X509Certificate certificate, byte[] message)
throws GeneralSecurityException {
Signature signature = Signature.getInstance(certificate.getSigAlgName());
signature.initSign(privateKey);
signature.update(message);
return signature.sign();
}

static boolean verify(X509Certificate certificate, byte[] message, byte[] sig) throws GeneralSecurityException {
Signature signature = Signature.getInstance(certificate.getSigAlgName());
signature.initVerify(certificate);
signature.update(message);
return signature.verify(sig);
}
}

在上述代码中,我们从key store直接读取了私钥-公钥对,私钥以PrivateKey实例表示,公钥以X509Certificate表示,实际上数字证书只包含公钥,因此,读取证书并不需要口令,只有读取私钥才需要。如果部署到Web服务器上,例如Nginx,需要把私钥导出为Private Key格式,把证书导出为X509Certificate格式。

以HTTPS协议为例,浏览器和服务器建立安全连接的步骤如下:

  1. 浏览器向服务器发起请求,服务器向浏览器发送自己的数字证书;
  2. 浏览器用操作系统内置的Root CA来验证服务器的证书是否有效,如果有效,就使用该证书加密一个随机的AES口令并发送给服务器;
  3. 服务器用自己的私钥解密获得AES口令,并在后续通讯中使用AES加密。

上述流程只是一种最常见的单向验证。如果服务器还要验证客户端,那么客户端也需要把自己的证书发送给服务器验证,这种场景常见于网银等。

注意:数字证书存储的是公钥,以及相关的证书链和算法信息。私钥必须严格保密,如果数字证书对应的私钥泄漏,就会造成严重的安全威胁。如果CA证书的私钥泄漏,那么该CA证书签发的所有证书将不可信。数字证书服务商DigiNotar就发生过私钥泄漏导致公司破产的事故。

练习

使用数字证书实现消息加密。

下载练习

小结

数字证书就是集合了多种密码学算法,用于实现数据加解密、身份认证、签名等多种功能的一种安全标准;

数字证书采用链式签名管理,顶级的Root CA证书已内置在操作系统中;

数字证书存储的是公钥,可以安全公开,而私钥必须严格保密。

留言與分享

JAVA-正则表达式

分類 编程语言, Java

正则表达式简介

在了解正则表达式之前,我们先看几个非常常见的问题:

  • 如何判断字符串是否是有效的电话号码?例如:010-1234567123ABC45613510001000等;
  • 如何判断字符串是否是有效的电子邮件地址?例如:test@example.comtest#example等;
  • 如何判断字符串是否是有效的时间?例如:12:3409:6099:99等。

一种直观的想法是通过程序判断,这种方法需要为每种用例创建规则,然后用代码实现。下面是判断手机号的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
boolean isValidMobileNumber(String s) {
// 是否是11位?
if (s.length() != 11) {
return false;
}
// 每一位都是0~9:
for (int i=0; i<s.length(); i++) {
char c = s.charAt(i);
if (c < '0' || c > '9') {
return false;
}
}
return true;
}

上述代码仅仅做了非常粗略的判断,并未考虑首位数字不能为0等更详细的情况。

除了判断手机号,我们还需要判断电子邮件地址、电话、邮编等等:

  • boolean isValidMobileNumber(String s) { … }
  • boolean isValidEmail(String s) { … }
  • boolean isValidPhoneNumber(String s) { … }
  • boolean isValidZipCode(String s) { … }

为每一种判断逻辑编写代码实在是太繁琐了。有没有更简单的方法?

有!用正则表达式!

正则表达式可以用字符串来描述规则,并用来匹配字符串。例如,判断手机号,我们用正则表达式\d{11}

1
2
3
boolean isValidMobileNumber(String s) {
return s.matches("\\d{11}");
}

使用正则表达式的好处有哪些?一个正则表达式就是一个描述规则的字符串,所以,只需要编写正确的规则,我们就可以让正则表达式引擎去判断目标字符串是否符合规则。

正则表达式是一套标准,它可以用于任何语言。Java标准库的java.util.regex包内置了正则表达式引擎,在Java程序中使用正则表达式非常简单。

举个例子:要判断用户输入的年份是否是20##年,我们先写出规则如下:

一共有4个字符,分别是:200~9任意数字0~9任意数字

对应的正则表达式就是:20\d\d,其中\d表示任意一个数字。

把正则表达式转换为Java字符串就变成了20\\d\\d,注意Java字符串用\\表示\

最后,用正则表达式匹配一个字符串的代码如下:

1
2
3
4
5
6
7
8
// regex
public class Main {
public static void main(String[] args) {
String regex = "20\\d\\d";
System.out.println("2019".matches(regex)); // true
System.out.println("2100".matches(regex)); // false
}
}

可见,使用正则表达式,不必编写复杂的代码来判断,只需给出一个字符串表达的正则规则即可。

小结

正则表达式是用字符串描述的一个匹配规则,使用正则表达式可以快速判断给定的字符串是否符合匹配规则;

Java标准库java.util.regex内建了正则表达式引擎。



正则表达式的匹配规则是从左到右按规则匹配。我们首先来看如何使用正则表达式来做精确匹配。

对于正则表达式abc来说,它只能精确地匹配字符串"abc",不能匹配"ab""Abc""abcd"等其他任何字符串。

如果正则表达式有特殊字符,那就需要用\转义。例如,正则表达式a\&c,其中\&是用来匹配特殊字符&的,它能精确匹配字符串"a&c",但不能匹配"ac""a-c""a&&c"等。

要注意正则表达式在Java代码中也是一个字符串,所以,对于正则表达式a\&c来说,对应的Java字符串是"a\\&c",因为\也是Java字符串的转义字符,两个\\实际上表示的是一个\

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// regex
public class Main {
public static void main(String[] args) {
String re1 = "abc";
System.out.println("abc".matches(re1));
System.out.println("Abc".matches(re1));
System.out.println("abcd".matches(re1));

String re2 = "a\\&c"; // 对应的正则是a\&c
System.out.println("a&c".matches(re2));
System.out.println("a-c".matches(re2));
System.out.println("a&&c".matches(re2));
}
}

如果想匹配非ASCII字符,例如中文,那就用\u####的十六进制表示,例如:a\u548cc匹配字符串"a和c",中文字符的Unicode编码是548c

匹配任意字符

精确匹配实际上用处不大,因为我们直接用String.equals()就可以做到。大多数情况下,我们想要的匹配规则更多的是模糊匹配。我们可以用.匹配一个任意字符。

例如,正则表达式a.c中间的.可以匹配一个任意字符,例如,下面的字符串都可以被匹配:

  • "abc",因为.可以匹配字符b
  • "a&c",因为.可以匹配字符&
  • "acc",因为.可以匹配字符c

但它不能匹配"ac""a&&c",因为.匹配一个字符且仅限一个字符。

匹配数字

.可以匹配任意字符,这个口子开得有点大。如果我们只想匹配0~9这样的数字,可以用\d匹配。例如,正则表达式00\d可以匹配:

  • "007",因为\d可以匹配字符7
  • "008",因为\d可以匹配字符8

它不能匹配"00A""0077",因为\d仅限单个数字字符。

匹配常用字符

\w可以匹配一个字母、数字或下划线,w的意思是word。例如,java\w可以匹配:

  • "javac",因为\w可以匹配英文字符c
  • "java9",因为\w可以匹配数字字符9;。
  • "java_",因为\w可以匹配下划线_

它不能匹配"java#""java ",因为\w不能匹配#、空格等字符。

匹配空格字符

\s可以匹配一个空格字符,注意空格字符不但包括空格 ,还包括tab字符(在Java中用\t表示)。例如,a\sc可以匹配:

  • "a c",因为\s可以匹配空格字符 ;
  • "a c",因为\s可以匹配tab字符\t

它不能匹配"ac""abc"等。

匹配非数字

\d可以匹配一个数字,而\D则匹配一个非数字。例如,00\D可以匹配:

  • "00A",因为\D可以匹配非数字字符A
  • "00#",因为\D可以匹配非数字字符#

00\d可以匹配的字符串"007""008"等,00\D是不能匹配的。

类似的,\W可以匹配\w不能匹配的字符,\S可以匹配\s不能匹配的字符,这几个正好是反着来的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// regex
public class Main {
public static void main(String[] args) {
String re1 = "java\\d"; // 对应的正则是java\d
System.out.println("java9".matches(re1));
System.out.println("java10".matches(re1));
System.out.println("javac".matches(re1));

String re2 = "java\\D";
System.out.println("javax".matches(re2));
System.out.println("java#".matches(re2));
System.out.println("java5".matches(re2));
}
}

重复匹配

我们用\d可以匹配一个数字,例如,A\d可以匹配"A0""A1",如果要匹配多个数字,比如"A380",怎么办?

修饰符*可以匹配任意个字符,包括0个字符。我们用A\d*可以匹配:

  • A:因为\d*可以匹配0个数字;
  • A0:因为\d*可以匹配1个数字0
  • A380:因为\d*可以匹配多个数字380

修饰符+可以匹配至少一个字符。我们用A\d+可以匹配:

  • A0:因为\d+可以匹配1个数字0
  • A380:因为\d+可以匹配多个数字380

但它无法匹配"A",因为修饰符+要求至少一个字符。

修饰符?可以匹配0个或一个字符。我们用A\d?可以匹配:

  • A:因为\d?可以匹配0个数字;
  • A0:因为\d?可以匹配1个数字0

但它无法匹配"A380",因为修饰符?超过1个字符就不能匹配了。

如果我们想精确指定n个字符怎么办?用修饰符{n}就可以。A\d{3}可以精确匹配:

  • A380:因为\d{3}可以匹配3个数字380

如果我们想指定匹配n~m个字符怎么办?用修饰符{n,m}就可以。A\d{3,5}可以精确匹配:

  • A380:因为\d{3,5}可以匹配3个数字380
  • A3800:因为\d{3,5}可以匹配4个数字3800
  • A38000:因为\d{3,5}可以匹配5个数字38000

如果没有上限,那么修饰符{n,}就可以匹配至少n个字符。

练习

请编写一个正则表达式匹配国内的电话号码规则:3~4位区号加7~8位电话,中间用-连接,例如:010-12345678

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
// regex
import java.util.*;

public class Main {
public static void main(String[] args) throws Exception {
String re = "\\d";
for (String s : List.of("010-12345678", "020-9999999", "0755-7654321")) {
if (!s.matches(re)) {
System.out.println("测试失败: " + s);
return;
}
}
for (String s : List.of("010 12345678", "A20-9999999", "0755-7654.321")) {
if (s.matches(re)) {
System.out.println("测试失败: " + s);
return;
}
}
System.out.println("测试成功!");
}
}

下载练习

进阶:国内区号必须以0开头,而电话号码不能以0开头,试修改正则表达式,使之能更精确地匹配。

提示:\d\D这种简单的规则暂时做不到,我们需要更复杂规则,后面会详细讲解。

小结

单个字符的匹配规则如下:

正则表达式 规则 可以匹配
A 指定字符 A
\u548c 指定Unicode字符
. 任意字符 ab&0
\d 数字0~9 0~9
\w 大小写字母,数字和下划线 a~zA~Z0~9_
\s 空格、Tab键 空格,Tab
\D 非数字 aA&_,……
\W 非\w &@,……
\S 非\s aA&_,……

多个字符的匹配规则如下:

正则表达式 规则 可以匹配
A* 任意个数字符 空,AAAAAA,……
A+ 至少1个字符 AAAAAA,……
A? 0个或1个字符 空,A
A{3} 指定个数字符 AAA
A{2,3} 指定范围个数字符 AAAAA
A{2,} 至少n个字符 AAAAAAAAA,……
A{0,3} 最多n个字符 空,AAAAAA

复杂匹配规则

匹配开头和结尾

用正则表达式进行多行匹配时,我们用^表示开头,$表示结尾。例如,^A\d{3}$,可以匹配"A001""A380"

匹配指定范围

如果我们规定一个7~8位数字的电话号码不能以0开头,应该怎么写匹配规则呢?\d{7,8}是不行的,因为第一个\d可以匹配到0

使用[...]可以匹配范围内的字符,例如,[123456789]可以匹配1~9,这样就可以写出上述电话号码的规则:[123456789]\d{6,7}

把所有字符全列出来太麻烦,[...]还有一种写法,直接写[1-9]就可以。

要匹配大小写不限的十六进制数,比如1A2b3c,我们可以这样写:[0-9a-fA-F],它表示一共可以匹配以下任意范围的字符:

  • 0-9:字符0~9
  • a-f:字符a~f
  • A-F:字符A~F

如果要匹配6位十六进制数,前面讲过的{n}仍然可以继续配合使用:[0-9a-fA-F]{6}

[...]还有一种排除法,即不包含指定范围的字符。假设我们要匹配任意字符,但不包括数字,可以写[^1-9]{3}

  • 可以匹配"ABC",因为不包含字符1~9
  • 可以匹配"A00",因为不包含字符1~9
  • 不能匹配"A01",因为包含字符1
  • 不能匹配"A05",因为包含字符5

或规则匹配

|连接的两个正则规则是规则,例如,AB|CD表示可以匹配ABCD

我们来看这个正则表达式java|php

1
2
3
4
5
6
7
8
9
// regex
public class Main {
public static void main(String[] args) {
String re = "java|php";
System.out.println("java".matches(re));
System.out.println("php".matches(re));
System.out.println("go".matches(re));
}
}

它可以匹配"java""php",但无法匹配"go"

要把go也加进来匹配,可以改写为java|php|go

使用括号

现在我们想要匹配字符串learn javalearn phplearn go怎么办?一个最简单的规则是learn\sjava|learn\sphp|learn\sgo,但是这个规则太复杂了,可以把公共部分提出来,然后用(...)把子规则括起来表示成learn\s(java|php|go)

1
2
3
4
5
6
7
8
9
10
// regex
public class Main {
public static void main(String[] args) {
String re = "learn\\s(java|php|go)";
System.out.println("learn java".matches(re));
System.out.println("learn Java".matches(re));
System.out.println("learn php".matches(re));
System.out.println("learn Go".matches(re));
}
}

上面的规则仍然不能匹配learn Javalearn Go这样的字符串。试修改正则,使之能匹配大写字母开头的learn Javalearn Phplearn Go

小结

复杂匹配规则主要有:

正则表达式 规则 可以匹配
^ 开头 字符串开头
$ 结尾 字符串结束
[ABC] […]内任意字符 A,B,C
[A-F0-9xy] 指定范围的字符 A,……,F0,……,9xy
[^A-F] 指定范围外的任意字符 A~F
AB CD EF


我们前面讲到的(...)可以用来把一个子规则括起来,这样写learn\s(java|php|go)就可以更方便地匹配长字符串了。

实际上(...)还有一个重要作用,就是分组匹配。

我们来看一下如何用正则匹配区号-电话号码这个规则。利用前面讲到的匹配规则,写出来很容易:

1
\d{3,4}\-\d{6,8}

虽然这个正则匹配规则很简单,但是往往匹配成功后,下一步是提取区号和电话号码,分别存入数据库。于是问题来了:如何提取匹配的子串?

当然可以用String提供的indexOf()substring()这些方法,但它们从正则匹配的字符串中提取子串没有通用性,下一次要提取learn\s(java|php)还得改代码。

正确的方法是用(...)先把要提取的规则分组,把上述正则表达式变为(\d{3,4})\-(\d{6,8})

现在问题又来了:匹配后,如何按括号提取子串?

现在我们没办法用String.matches()这样简单的判断方法了,必须引入java.util.regex包,用Pattern对象匹配,匹配后获得一个Matcher对象,如果匹配成功,就可以直接从Matcher.group(index)返回子串:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import java.util.regex.*;

public class Main {
public static void main(String[] args) {
Pattern p = Pattern.compile("(\\d{3,4})\\-(\\d{7,8})");
Matcher m = p.matcher("010-12345678");
if (m.matches()) {
String g1 = m.group(1);
String g2 = m.group(2);
System.out.println(g1); // 010
System.out.println(g2); // 12345678
} else {
System.out.println("匹配失败!");
}
}
}

运行上述代码,会得到两个匹配上的子串01012345678

要特别注意,Matcher.group(index)方法的参数用1表示第一个子串,2表示第二个子串。如果我们传入0会得到什么呢?答案是010-12345678,即整个正则匹配到的字符串。

Pattern

我们在前面的代码中用到的正则表达式代码是String.matches()方法,而我们在分组提取的代码中用的是java.util.regex包里面的Pattern类和Matcher类。实际上这两种代码本质上是一样的,因为String.matches()方法内部调用的就是PatternMatcher类的方法。

但是反复使用String.matches()对同一个正则表达式进行多次匹配效率较低,因为每次都会创建出一样的Pattern对象。完全可以先创建出一个Pattern对象,然后反复使用,就可以实现编译一次,多次匹配:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import java.util.regex.*;

public class Main {
public static void main(String[] args) {
Pattern pattern = Pattern.compile("(\\d{3,4})\\-(\\d{7,8})");
pattern.matcher("010-12345678").matches(); // true
pattern.matcher("021-123456").matches(); // false
pattern.matcher("022#1234567").matches(); // false
// 获得Matcher对象:
Matcher matcher = pattern.matcher("010-12345678");
if (matcher.matches()) {
String whole = matcher.group(0); // "010-12345678", 0表示匹配的整个字符串
String area = matcher.group(1); // "010", 1表示匹配的第1个子串
String tel = matcher.group(2); // "12345678", 2表示匹配的第2个子串
System.out.println(area);
System.out.println(tel);
}
}
}

使用Matcher时,必须首先调用matches()判断是否匹配成功,匹配成功后,才能调用group()提取子串。

利用提取子串的功能,我们轻松获得了区号和号码两部分。

练习

利用分组匹配,从字符串"23:01:59"提取时、分、秒。

下载练习

小结

正则表达式用(...)分组可以通过Matcher对象快速提取子串:

  • group(0)表示匹配的整个字符串;
  • group(1)表示第1个子串,group(2)表示第2个子串,以此类推。

非贪婪匹配

在介绍非贪婪匹配前,我们先看一个简单的问题:

给定一个字符串表示的数字,判断该数字末尾0的个数。例如:

  • "123000":3个0
  • "10100":2个0
  • "1001":0个0

可以很容易地写出该正则表达式:(\d+)(0*),Java代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
import java.util.regex.*;

public class Main {
public static void main(String[] args) {
Pattern pattern = Pattern.compile("(\\d+)(0*)");
Matcher matcher = pattern.matcher("1230000");
if (matcher.matches()) {
System.out.println("group1=" + matcher.group(1)); // "1230000"
System.out.println("group2=" + matcher.group(2)); // ""
}
}
}

然而打印的第二个子串是空字符串""

实际上,我们期望分组匹配结果是:

input \d+ 0*
123000 “123” “000”
10100 “101” “00”
1001 “1001” “”

但实际的分组匹配结果是这样的:

input \d+ 0*
123000 “123000” “”
10100 “10100” “”
1001 “1001” “”

仔细观察上述实际匹配结果,实际上它是完全合理的,因为\d+确实可以匹配后面任意个0

这是因为正则表达式默认使用贪婪匹配:任何一个规则,它总是尽可能多地向后匹配,因此,\d+总是会把后面的0包含进来。

要让\d+尽量少匹配,让0*尽量多匹配,我们就必须让\d+使用非贪婪匹配。在规则\d+后面加个?即可表示非贪婪匹配。我们改写正则表达式如下:

1
2
3
4
5
6
7
8
9
10
11
12
import java.util.regex.*;

public class Main {
public static void main(String[] args) {
Pattern pattern = Pattern.compile("(\\d+?)(0*)");
Matcher matcher = pattern.matcher("1230000");
if (matcher.matches()) {
System.out.println("group1=" + matcher.group(1)); // "123"
System.out.println("group2=" + matcher.group(2)); // "0000"
}
}
}

因此,给定一个匹配规则,加上?后就变成了非贪婪匹配。

我们再来看这个正则表达式(\d??)(9*),注意\d?表示匹配0个或1个数字,后面第二个?表示非贪婪匹配,因此,给定字符串"9999",匹配到的两个子串分别是"""9999",因为对于\d?来说,可以匹配1个9,也可以匹配0个9,但是因为后面的?表示非贪婪匹配,它就会尽可能少的匹配,结果是匹配了0个9

小结

正则表达式匹配默认使用贪婪匹配,可以使用?表示对某一规则进行非贪婪匹配;

注意区分?的含义:\d??



分割字符串

使用正则表达式分割字符串可以实现更加灵活的功能。String.split()方法传入的正是正则表达式。我们来看下面的代码:

1
2
3
"a b c".split("\\s"); // { "a", "b", "c" }
"a b c".split("\\s"); // { "a", "b", "", "c" }
"a, b ;; c".split("[\\,\\;\\s]+"); // { "a", "b", "c" }

如果我们想让用户输入一组标签,然后把标签提取出来,因为用户的输入往往是不规范的,这时,使用合适的正则表达式,就可以消除多个空格、混合,;这些不规范的输入,直接提取出规范的字符串。

搜索字符串

使用正则表达式还可以搜索字符串,我们来看例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
import java.util.regex.*;

public class Main {
public static void main(String[] args) {
String s = "the quick brown fox jumps over the lazy dog.";
Pattern p = Pattern.compile("\\wo\\w");
Matcher m = p.matcher(s);
while (m.find()) {
String sub = s.substring(m.start(), m.end());
System.out.println(sub);
}
}
}

我们获取到Matcher对象后,不需要调用matches()方法(因为匹配整个串肯定返回false),而是反复调用find()方法,在整个串中搜索能匹配上\\wo\\w规则的子串,并打印出来。这种方式比String.indexOf()要灵活得多,因为我们搜索的规则是3个字符:中间必须是o,前后两个必须是字符[A-Za-z0-9_]

替换字符串

使用正则表达式替换字符串可以直接调用String.replaceAll(),它的第一个参数是正则表达式,第二个参数是待替换的字符串。我们还是来看例子:

1
2
3
4
5
6
7
8
// regex
public class Main {
public static void main(String[] args) {
String s = "The quick\t\t brown fox jumps over the lazy dog.";
String r = s.replaceAll("\\s+", " ");
System.out.println(r); // "The quick brown fox jumps over the lazy dog."
}
}

上面的代码把不规范的连续空格分隔的句子变成了规范的句子。可见,灵活使用正则表达式可以大大降低代码量。

反向引用

如果我们要把搜索到的指定字符串按规则替换,比如前后各加一个<b>xxxx</b>,这个时候,使用replaceAll()的时候,我们传入的第二个参数可以使用$1$2来反向引用匹配到的子串。例如:

1
2
3
4
5
6
7
8
// regex
public class Main {
public static void main(String[] args) {
String s = "the quick brown fox jumps over the lazy dog.";
String r = s.replaceAll("\\s([a-z]{4})\\s", " <b>$1</b> ");
System.out.println(r);
}
}

上述代码的运行结果是:

1
the quick brown fox jumps <b>over</b> the <b>lazy</b> dog.

它实际上把任何4字符单词的前后用<b>xxxx</b>括起来。实现替换的关键就在于" <b>$1</b> ",它用匹配的分组子串([a-z]{4})替换了$1

练习

模板引擎是指,定义一个字符串作为模板:

1
Hello, ${name}! You are learning ${lang}!

其中,以${key}表示的是变量,也就是将要被替换的内容。

当传入一个Map<String, String>给模板后,需要把对应的key替换为Map的value。

例如,传入Map为:

1
2
3
4
{
"name": "Bob",
"lang": "Java"
}

然后,${name}被替换为Map对应的值Bob${lang}被替换为Map对应的值Java,最终输出的结果为:

1
Hello, Bob! You are learning Java!

请编写一个简单的模板引擎,利用正则表达式实现这个功能。

提示:参考Matcher.appendReplacement()方法。

下载练习

小结

使用正则表达式可以:

  • 分割字符串:String.split()
  • 搜索子串:Matcher.find()
  • 替换字符串:String.replaceAll()

留言與分享

JAVA-单元测试

分類 编程语言, Java

什么是单元测试呢?单元测试就是针对最小的功能单元编写测试代码。Java程序最小的功能单元是方法,因此,对Java程序进行单元测试就是针对单个Java方法的测试。

单元测试有什么好处呢?在学习单元测试前,我们可以先了解一下测试驱动开发。

所谓测试驱动开发,是指先编写接口,紧接着编写测试。编写完测试后,我们才开始真正编写实现代码。在编写实现代码的过程中,一边写,一边测,什么时候测试全部通过了,那就表示编写的实现完成了:

1
2
3
4
5
6
7
8
9
10
11
12
13
    编写接口


编写测试


┌─▶ 编写实现
│ │
│ N ▼
└── 运行测试
│ Y

任务完成

这就是传说中的……

tdd

当然,这是一种理想情况。大部分情况是我们已经编写了实现代码,需要对已有的代码进行测试。

我们先通过一个示例来看如何编写测试。假定我们编写了一个计算阶乘的类,它只有一个静态方法来计算阶乘:

代码如下:

1
2
3
4
5
6
7
8
9
public class Factorial {
public static long fact(long n) {
long r = 1;
for (long i = 1; i <= n; i++) {
r = r * i;
}
return r;
}
}

要测试这个方法,一个很自然的想法是编写一个main()方法,然后运行一些测试代码:

1
2
3
4
5
6
7
8
9
public class Test {
public static void main(String[] args) {
if (fact(10) == 3628800) {
System.out.println("pass");
} else {
System.out.println("fail");
}
}
}

这样我们就可以通过运行main()方法来运行测试代码。

不过,使用main()方法测试有很多缺点:

一是只能有一个main()方法,不能把测试代码分离,二是没有打印出测试结果和期望结果,例如,expected: 3628800, but actual: 123456,三是很难编写一组通用的测试代码。

因此,我们需要一种测试框架,帮助我们编写测试。

JUnit

JUnit是一个开源的Java语言的单元测试框架,专门针对Java设计,使用最广泛。JUnit是事实上的单元测试的标准框架,任何Java开发者都应当学习并使用JUnit编写单元测试。

使用JUnit编写单元测试的好处在于,我们可以非常简单地组织测试代码,并随时运行它们,JUnit就会给出成功的测试和失败的测试,还可以生成测试报告,不仅包含测试的成功率,还可以统计测试的代码覆盖率,即被测试的代码本身有多少经过了测试。对于高质量的代码来说,测试覆盖率应该在80%以上。

此外,几乎所有的IDE工具都集成了JUnit,这样我们就可以直接在IDE中编写并运行JUnit测试。JUnit目前最新版本是5。

以Eclipse为例,当我们已经编写了一个Factorial.java文件后,我们想对其进行测试,需要编写一个对应的FactorialTest.java文件,以Test为后缀是一个惯例,并分别将其放入srctest目录中。最后,在Project - Properties - Java Build Path - Libraries中添加JUnit 5的库:

junit-lib

整个项目结构如下:

junit-test-structure

我们来看一下FactorialTest.java的内容:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
package com.itranswarp.learnjava;

import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.Test;

public class FactorialTest {

@Test
void testFact() {
assertEquals(1, Factorial.fact(1));
assertEquals(2, Factorial.fact(2));
assertEquals(6, Factorial.fact(3));
assertEquals(3628800, Factorial.fact(10));
assertEquals(2432902008176640000L, Factorial.fact(20));
}
}

核心测试方法testFact()加上了@Test注解,这是JUnit要求的,它会把带有@Test的方法识别为测试方法。在测试方法内部,我们用assertEquals(1, Factorial.fact(1))表示,期望Factorial.fact(1)返回1assertEquals(expected, actual)是最常用的测试方法,它在Assertion类中定义。Assertion还定义了其他断言方法,例如:

  • assertTrue(): 期待结果为true
  • assertFalse(): 期待结果为false
  • assertNotNull(): 期待结果为非null
  • assertArrayEquals(): 期待结果为数组并与期望数组每个元素的值均相等

运行单元测试非常简单。选中FactorialTest.java文件,点击Run - Run As - JUnit Test,Eclipse会自动运行这个JUnit测试,并显示结果:

junit-test-ok

如果测试结果与预期不符,assertEquals()会抛出异常,我们就会得到一个测试失败的结果:

junit-test-failed

在Failure Trace中,JUnit会告诉我们详细的错误结果:

1
2
3
4
5
6
7
8
9
org.opentest4j.AssertionFailedError: expected: <3628800> but was: <362880>
at org.junit.jupiter.api.AssertionUtils.fail(AssertionUtils.java:55)
at org.junit.jupiter.api.AssertEquals.failNotEqual(AssertEquals.java:195)
at org.junit.jupiter.api.AssertEquals.assertEquals(AssertEquals.java:168)
at org.junit.jupiter.api.AssertEquals.assertEquals(AssertEquals.java:163)
at org.junit.jupiter.api.Assertions.assertEquals(Assertions.java:611)
at com.itranswarp.learnjava.FactorialTest.testFact(FactorialTest.java:14)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at ...

第一行的失败信息的意思是期待结果3628800但是实际返回是362880,此时,我们要么修正实现代码,要么修正测试代码,直到测试通过为止。

使用浮点数时,由于浮点数无法精确地进行比较,因此,我们需要调用assertEquals(double expected, double actual, double delta)这个重载方法,指定一个误差值:

1
assertEquals(0.1, Math.abs(1 - 9 / 10.0), 0.0000001);

单元测试的好处

单元测试可以确保单个方法按照正确预期运行,如果修改了某个方法的代码,只需确保其对应的单元测试通过,即可认为改动正确。此外,测试代码本身就可以作为示例代码,用来演示如何调用该方法。

使用JUnit进行单元测试,我们可以使用断言(Assertion)来测试期望结果,可以方便地组织和运行测试,并方便地查看测试结果。此外,JUnit既可以直接在IDE中运行,也可以方便地集成到Maven这些自动化工具中运行。

在编写单元测试的时候,我们要遵循一定的规范:

一是单元测试代码本身必须非常简单,能一下看明白,决不能再为测试代码编写测试;

二是每个单元测试应当互相独立,不依赖运行的顺序;

三是测试时不但要覆盖常用测试用例,还要特别注意测试边界条件,例如输入为0null,空字符串""等情况。

练习

使用JUnit编写测试代码。

下载练习

小结

JUnit是一个单元测试框架,专门用于运行我们编写的单元测试:

一个JUnit测试包含若干@Test方法,并使用Assertions进行断言,注意浮点数assertEquals()要指定delta

在一个单元测试中,我们经常编写多个@Test方法,来分组、分类对目标代码进行测试。

在测试的时候,我们经常遇到一个对象需要初始化,测试完可能还需要清理的情况。如果每个@Test方法都写一遍这样的重复代码,显然比较麻烦。

JUnit提供了编写测试前准备、测试后清理的固定代码,我们称之为Fixture。

我们来看一个具体的Calculator的例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
public class Calculator {
private long n = 0;

public long add(long x) {
n = n + x;
return n;
}

public long sub(long x) {
n = n - x;
return n;
}
}

这个类的功能很简单,但是测试的时候,我们要先初始化对象,我们不必在每个测试方法中都写上初始化代码,而是通过@BeforeEach来初始化,通过@AfterEach来清理资源:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
public class CalculatorTest {

Calculator calculator;

@BeforeEach
public void setUp() {
this.calculator = new Calculator();
}

@AfterEach
public void tearDown() {
this.calculator = null;
}

@Test
void testAdd() {
assertEquals(100, this.calculator.add(100));
assertEquals(150, this.calculator.add(50));
assertEquals(130, this.calculator.add(-20));
}

@Test
void testSub() {
assertEquals(-100, this.calculator.sub(100));
assertEquals(-150, this.calculator.sub(50));
assertEquals(-130, this.calculator.sub(-20));
}
}

CalculatorTest测试中,有两个标记为@BeforeEach@AfterEach的方法,它们会在运行每个@Test方法前后自动运行。

上面的测试代码在JUnit中运行顺序如下:

1
2
3
4
5
6
for (Method testMethod : findTestMethods(CalculatorTest.class)) {
var test = new CalculatorTest(); // 创建Test实例
invokeBeforeEach(test);
invokeTestMethod(test, testMethod);
invokeAfterEach(test);
}

可见,@BeforeEach@AfterEach会“环绕”在每个@Test方法前后。

还有一些资源初始化和清理可能更加繁琐,而且会耗费较长的时间,例如初始化数据库。JUnit还提供了@BeforeAll@AfterAll,它们在运行所有@Test前后运行,顺序如下:

1
2
3
4
5
6
7
8
invokeBeforeAll(CalculatorTest.class);
for (Method testMethod : findTestMethods(CalculatorTest.class)) {
var test = new CalculatorTest(); // 创建Test实例
invokeBeforeEach(test);
invokeTestMethod(test, testMethod);
invokeAfterEach(test);
}
invokeAfterAll(CalculatorTest.class);

因为@BeforeAll@AfterAll在所有@Test方法运行前后仅运行一次,因此,它们只能初始化静态变量,例如:

1
2
3
4
5
6
7
8
9
10
11
12
13
public class DatabaseTest {
static Database db;

@BeforeAll
public static void initDatabase() {
db = createDb(...);
}

@AfterAll
public static void dropDatabase() {
...
}
}

事实上,@BeforeAll@AfterAll也只能标注在静态方法上。

因此,我们总结出编写Fixture的套路如下:

  1. 对于实例变量,在@BeforeEach中初始化,在@AfterEach中清理,它们在各个@Test方法中互不影响,因为是不同的实例;
  2. 对于静态变量,在@BeforeAll中初始化,在@AfterAll中清理,它们在各个@Test方法中均是唯一实例,会影响各个@Test方法。

大多数情况下,使用@BeforeEach@AfterEach就足够了。只有某些测试资源初始化耗费时间太长,以至于我们不得不尽量“复用”时才会用到@BeforeAll@AfterAll

最后,注意到每次运行一个@Test方法前,JUnit首先创建一个XxxTest实例,因此,每个@Test方法内部的成员变量都是独立的,不能也无法把成员变量的状态从一个@Test方法带到另一个@Test方法。

练习

使用Fixture编写单元测试。

下载练习

小结

编写Fixture是指针对每个@Test方法,编写@BeforeEach方法用于初始化测试资源,编写@AfterEach用于清理测试资源;

必要时,可以编写@BeforeAll@AfterAll,使用静态变量来初始化耗时的资源,并且在所有@Test方法的运行前后仅执行一次。

异常测试

在Java程序中,异常处理是非常重要的。

我们自己编写的方法,也经常抛出各种异常。对于可能抛出的异常进行测试,本身就是测试的重要环节。

因此,在编写JUnit测试的时候,除了正常的输入输出,我们还要特别针对可能导致异常的情况进行测试。

我们仍然用Factorial举例:

1
2
3
4
5
6
7
8
9
10
11
12
public class Factorial {
public static long fact(long n) {
if (n < 0) {
throw new IllegalArgumentException();
}
long r = 1;
for (long i = 1; i <= n; i++) {
r = r * i;
}
return r;
}
}

在方法入口,我们增加了对参数n的检查,如果为负数,则直接抛出IllegalArgumentException

现在,我们希望对异常进行测试。在JUnit测试中,我们可以编写一个@Test方法专门测试异常:

1
2
3
4
5
6
7
8
9
@Test
void testNegative() {
assertThrows(IllegalArgumentException.class, new Executable() {
@Override
public void execute() throws Throwable {
Factorial.fact(-1);
}
});
}

JUnit提供assertThrows()来期望捕获一个指定的异常。第二个参数Executable封装了我们要执行的会产生异常的代码。当我们执行Factorial.fact(-1)时,必定抛出IllegalArgumentExceptionassertThrows()在捕获到指定异常时表示通过测试,未捕获到异常,或者捕获到的异常类型不对,均表示测试失败。

有些童鞋会觉得编写一个Executable的匿名类太繁琐了。实际上,Java 8开始引入了函数式编程,所有单方法接口都可以简写如下:

1
2
3
4
5
6
@Test
void testNegative() {
assertThrows(IllegalArgumentException.class, () -> {
Factorial.fact(-1);
});
}

上述奇怪的->语法就是函数式接口的实现代码,我们会在后面详细介绍。现在,我们只需要通过这种固定的代码编写能抛出异常的语句即可。

练习

观察Factorial.fact()方法,注意到由于long型整数有范围限制,当我们传入参数21时,得到的结果是-4249290049419214848,而不是期望的51090942171709440000,因此,当传入参数大于20时,Factorial.fact()方法应当抛出ArithmeticException。请编写测试并修改实现代码,确保测试通过。

下载练习

小结

测试异常可以使用assertThrows(),期待捕获到指定类型的异常;

对可能发生的每种类型的异常都必须进行测试。



在运行测试的时候,有些时候,我们需要排出某些@Test方法,不要让它运行,这时,我们就可以给它标记一个@Disabled

1
2
3
4
5
@Disabled
@Test
void testBug101() {
// 这个测试不会运行
}

为什么我们不直接注释掉@Test,而是要加一个@Disabled?这是因为注释掉@Test,JUnit就不知道这是个测试方法,而加上@Disabled,JUnit仍然识别出这是个测试方法,只是暂时不运行。它会在测试结果中显示:

1
Tests run: 68, Failures: 2, Errors: 0, Skipped: 5

类似@Disabled这种注解就称为条件测试(Conditional Test),JUnit根据不同的条件注解,决定是否运行当前的@Test方法。

我们来看一个例子:

1
2
3
4
5
6
7
8
9
10
11
12
public class Config {
public String getConfigFile(String filename) {
String os = System.getProperty("os.name").toLowerCase();
if (os.contains("win")) {
return "C:\\" + filename;
}
if (os.contains("mac") || os.contains("linux") || os.contains("unix")) {
return "/usr/local/" + filename;
}
throw new UnsupportedOperationException();
}
}

我们想要测试getConfigFile()这个方法,但是在Windows上跑,和在Linux上跑的代码路径不同,因此,针对两个系统的测试方法,其中一个只能在Windows上跑,另一个只能在Mac/Linux上跑:

1
2
3
4
5
6
7
8
9
@Test
void testWindows() {
assertEquals("C:\\test.ini", config.getConfigFile("test.ini"));
}

@Test
void testLinuxAndMac() {
assertEquals("/usr/local/test.cfg", config.getConfigFile("test.cfg"));
}

因此,我们给上述两个测试方法分别加上条件如下:

1
2
3
4
5
6
7
8
9
10
11
@Test
@EnabledOnOs(OS.WINDOWS)
void testWindows() {
assertEquals("C:\\test.ini", config.getConfigFile("test.ini"));
}

@Test
@EnabledOnOs({ OS.LINUX, OS.MAC })
void testLinuxAndMac() {
assertEquals("/usr/local/test.cfg", config.getConfigFile("test.cfg"));
}

@EnableOnOs就是一个条件测试判断。

我们来看一些常用的条件测试:

不在Windows平台执行的测试,可以加上@DisabledOnOs(OS.WINDOWS)

1
2
3
4
5
@Test
@DisabledOnOs(OS.WINDOWS)
void testOnNonWindowsOs() {
// TODO: this test is disabled on windows
}

只能在Java 9或更高版本执行的测试,可以加上@DisabledOnJre(JRE.JAVA_8)

1
2
3
4
5
@Test
@DisabledOnJre(JRE.JAVA_8)
void testOnJava9OrAbove() {
// TODO: this test is disabled on java 8
}

只能在64位操作系统上执行的测试,可以用@EnabledIfSystemProperty判断:

1
2
3
4
5
@Test
@EnabledIfSystemProperty(named = "os.arch", matches = ".*64.*")
void testOnlyOn64bitSystem() {
// TODO: this test is only run on 64 bit system
}

需要传入环境变量DEBUG=true才能执行的测试,可以用@EnabledIfEnvironmentVariable

1
2
3
4
5
@Test
@EnabledIfEnvironmentVariable(named = "DEBUG", matches = "true")
void testOnlyOnDebugMode() {
// TODO: this test is only run on DEBUG=true
}

当我们在JUnit中运行所有测试的时候,JUnit会给出执行的结果。在IDE中,我们能很容易地看到没有执行的测试:

junit-conditional-test

带有⊘标记的测试方法表示没有执行。

练习

在JUnit测试中使用条件测试。

下载练习

小结

条件测试是根据某些注解在运行期让JUnit自动忽略某些测试。

如果待测试的输入和输出是一组数据,可以把测试数据组织起来,用不同的测试数据调用相同的测试方法,这就是参数化测试。

参数化测试和普通测试稍微不同的地方在于,一个测试方法需要接收至少一个参数,然后,传入一组参数反复运行。

JUnit提供了一个@ParameterizedTest注解,用来进行参数化测试。

假设我们想对Math.abs()进行测试,先用一组正数进行测试:

1
2
3
4
5
@ParameterizedTest
@ValueSource(ints = { 0, 1, 5, 100 })
void testAbs(int x) {
assertEquals(x, Math.abs(x));
}

再用一组负数进行测试:

1
2
3
4
5
@ParameterizedTest
@ValueSource(ints = { -1, -5, -100 })
void testAbsNegative(int x) {
assertEquals(-x, Math.abs(x));
}

注意到参数化测试的注解是@ParameterizedTest,而不是普通的@Test

实际的测试场景往往没有这么简单。假设我们自己编写了一个StringUtils.capitalize()方法,它会把字符串的第一个字母变为大写,后续字母变为小写:

1
2
3
4
5
6
7
8
public class StringUtils {
public static String capitalize(String s) {
if (s.length() == 0) {
return s;
}
return Character.toUpperCase(s.charAt(0)) + s.substring(1).toLowerCase();
}
}

要用参数化测试的方法来测试,我们不但要给出输入,还要给出预期输出。因此,测试方法至少需要接收两个参数:

1
2
3
4
@ParameterizedTest
void testCapitalize(String input, String result) {
assertEquals(result, StringUtils.capitalize(input));
}

现在问题来了:参数如何传入?

最简单的方法是通过@MethodSource注解,它允许我们编写一个同名的静态方法来提供测试参数:

1
2
3
4
5
6
7
8
9
10
11
12
@ParameterizedTest
@MethodSource
void testCapitalize(String input, String result) {
assertEquals(result, StringUtils.capitalize(input));
}

static List<Arguments> testCapitalize() {
return List.of( // arguments:
Arguments.of("abc", "Abc"), //
Arguments.of("APPLE", "Apple"), //
Arguments.of("gooD", "Good"));
}

上面的代码很容易理解:静态方法testCapitalize()返回了一组测试参数,每个参数都包含两个String,正好作为测试方法的两个参数传入。

提示

如果静态方法和测试方法的名称不同,@MethodSource也允许指定方法名。但使用默认同名方法最方便。

另一种传入测试参数的方法是使用@CsvSource,它的每一个字符串表示一行,一行包含的若干参数用,分隔,因此,上述测试又可以改写如下:

1
2
3
4
5
@ParameterizedTest
@CsvSource({ "abc, Abc", "APPLE, Apple", "gooD, Good" })
void testCapitalize(String input, String result) {
assertEquals(result, StringUtils.capitalize(input));
}

如果有成百上千的测试输入,那么,直接写@CsvSource就很不方便。这个时候,我们可以把测试数据提到一个独立的CSV文件中,然后标注上@CsvFileSource

1
2
3
4
5
@ParameterizedTest
@CsvFileSource(resources = { "/test-capitalize.csv" })
void testCapitalizeUsingCsvFile(String input, String result) {
assertEquals(result, StringUtils.capitalize(input));
}

JUnit只在classpath中查找指定的CSV文件,因此,test-capitalize.csv这个文件要放到test目录下,内容如下:

1
2
3
4
apple, Apple
HELLO, Hello
JUnit, Junit
reSource, Resource

练习

StringUtils进行参数化测试。

下载练习

小结

使用参数化测试,可以提供一组测试数据,对一个测试方法反复测试;

参数既可以在测试代码中写死,也可以通过@CsvFileSource放到外部的CSV文件中。

留言與分享

JAVA-日期与时间

分類 编程语言, Java

在计算机中,我们经常需要处理日期和时间。

这是日期:

  • 2019-11-20
  • 2020-1-1

这是时间:

  • 12:30:59
  • 2020-1-1 20:21:59

日期是指某一天,它不是连续变化的,而是应该被看成离散的。

而时间有两种概念,一种是不带日期的时间,例如,12:30:59。另一种是带日期的时间,例如,2020-1-1 20:21:59,只有这种带日期的时间能唯一确定某个时刻,不带日期的时间是无法确定一个唯一时刻的。

本地时间

当我们说当前时刻是2019年11月20日早上8:15的时候,我们说的实际上是本地时间。在国内就是北京时间。在这个时刻,如果地球上不同地方的人们同时看一眼手表,他们各自的本地时间是不同的:

localtime

所以,不同的时区,在同一时刻,本地时间是不同的。全球一共分为24个时区,伦敦所在的时区称为标准时区,其他时区按东/西偏移的小时区分,北京所在的时区是东八区。

时区

因为光靠本地时间还无法唯一确定一个准确的时刻,所以我们还需要给本地时间加上一个时区。时区有好几种表示方式。

一种是以GMT或者UTC加时区偏移表示,例如:GMT+08:00或者UTC+08:00表示东八区。

GMTUTC可以认为基本是等价的,只是UTC使用更精确的原子钟计时,每隔几年会有一个闰秒,我们在开发程序的时候可以忽略两者的误差,因为计算机的时钟在联网的时候会自动与时间服务器同步时间。

另一种是缩写,例如,CST表示China Standard Time,也就是中国标准时间。但是CST也可以表示美国中部时间Central Standard Time USA,因此,缩写容易产生混淆,我们尽量不要使用缩写。

最后一种是以洲/城市表示,例如,Asia/Shanghai,表示上海所在地的时区。特别注意城市名称不是任意的城市,而是由国际标准组织规定的城市。

因为时区的存在,东八区的2019年11月20日早上8:15,和西五区的2019年11月19日晚上19:15,他们的时刻是相同的:

timezone

时刻相同的意思就是,分别在两个时区的两个人,如果在这一刻通电话,他们各自报出自己手表上的时间,虽然本地时间是不同的,但是这两个时间表示的时刻是相同的。

夏令时

时区还不是最复杂的,更复杂的是夏令时。所谓夏令时,就是夏天开始的时候,把时间往后拨1小时,夏天结束的时候,再把时间往前拨1小时。我们国家实行过一段时间夏令时,1992年就废除了,但是矫情的美国人到现在还在使用,所以时间换算更加复杂。

daynight-saving

因为涉及到夏令时,相同的时区,如果表示的方式不同,转换出的时间是不同的。我们举个栗子:

对于2019-11-20和2019-6-20两个日期来说,假设北京人在纽约:

  • 如果以GMT或者UTC作为时区,无论日期是多少,时间都是19:00
  • 如果以国家/城市表示,例如America/NewYork,虽然纽约也在西五区,但是,因为夏令时的存在,在不同的日期,GMT时间和纽约时间可能是不一样的:
时区 2019-11-20 2019-6-20
GMT-05:00 19:00 19:00
UTC-05:00 19:00 19:00
America/New_York 19:00 20:00

实行夏令时的不同地区,进入和退出夏令时的时间很可能是不同的。同一个地区,根据历史上是否实行过夏令时,标准时间在不同年份换算成当地时间也是不同的。因此,计算夏令时,没有统一的公式,必须按照一组给定的规则来算,并且,该规则要定期更新。

注意

计算夏令时请使用标准库提供的相关类,不要试图自己计算夏令时。

本地化

在计算机中,通常使用Locale表示一个国家或地区的日期、时间、数字、货币等格式。Locale语言_国家的字母缩写构成,例如,zh_CN表示中文+中国,en_US表示英文+美国。语言使用小写,国家使用大写。

对于日期来说,不同的Locale,例如,中国和美国的表示方式如下:

  • zh_CN:2016-11-30
  • en_US:11/30/2016

计算机用Locale在日期、时间、货币和字符串之间进行转换。一个电商网站会根据用户所在的Locale对用户显示如下:

中国用户 美国用户
购买价格 12000.00 12,000.00
购买日期 2016-11-30 11/30/2016

小结

在编写日期和时间的程序前,我们要准确理解日期、时间和时刻的概念;

由于存在本地时间,我们需要理解时区的概念,并且必须牢记由于夏令时的存在,同一地区用GMT/UTC和城市表示的时区可能导致时间不同;

计算机通过Locale来针对当地用户习惯格式化日期、时间、数字、货币等。

在计算机中,应该如何表示日期和时间呢?

我们经常看到的日期和时间表示方式如下:

  • 2019-11-20 0:15:00 GMT+00:00
  • 2019年11月20日8:15:00
  • 11/19/2019 19:15:00 America/New_York

如果直接以字符串的形式存储,那么不同的格式,不同的语言会让表示方式非常繁琐。

在理解日期和时间的表示方式之前,我们先要理解数据的存储和展示。

当我们定义一个整型变量并赋值时:

1
int n = 123400;

编译器会把上述字符串(程序源码就是一个字符串)编译成字节码。在程序的运行期,变量n指向的内存实际上是一个4字节区域:

1
2
3
┌──┬──┬──┬──┐
│00│01│e2│08│
└──┴──┴──┴──┘

注意到计算机内存除了二进制的0/1外没有其他任何格式。上述十六进制是为了简化表示。

当我们用System.out.println(n)打印这个整数的时候,实际上println()这个方法在内部把int类型转换成String类型,然后打印出字符串123400

类似的,我们也可以以十六进制的形式打印这个整数,或者,如果n表示一个价格,我们就以$123,400.00的形式来打印它:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import java.text.*;
import java.util.*;

public class Main {
public static void main(String[] args) {
int n = 123400;
// 123400
System.out.println(n);
// 1e208
System.out.println(Integer.toHexString(n));
// $123,400.00
System.out.println(NumberFormat.getCurrencyInstance(Locale.US).format(n));
}
}

可见,整数123400是数据的存储格式,它的存储格式非常简单。而我们打印的各种各样的字符串,则是数据的展示格式。展示格式有多种形式,但本质上它就是一个转换方法:

1
String toDisplay(int n) { ... }

理解了数据的存储和展示,我们回头看看以下几种日期和时间:

  • 2019-11-20 0:15:01 GMT+00:00
  • 2019年11月20日8:15:01
  • 11/19/2019 19:15:01 America/New_York

它们实际上是数据的展示格式,分别按英国时区、中国时区、纽约时区对同一个时刻进行展示。而这个“同一个时刻”在计算机中存储的本质上只是一个整数,我们称它为Epoch Time

Epoch Time是计算从1970年1月1日零点(格林威治时区/GMT+00:00)到现在所经历的秒数,例如:

1574208900表示从1970年1月1日零点GMT时区到该时刻一共经历了1574208900秒,换算成伦敦、北京和纽约时间分别是:

1
2
3
1574208900 = 北京时间2019-11-20 8:15:00
= 伦敦时间2019-11-20 0:15:00
= 纽约时间2019-11-19 19:15:00

localtime

因此,在计算机中,只需要存储一个整数1574208900表示某一时刻。当需要显示为某一地区的当地时间时,我们就把它格式化为一个字符串:

1
String displayDateTime(int n, String timezone) { ... }

Epoch Time又称为时间戳,在不同的编程语言中,会有几种存储方式:

  • 以秒为单位的整数:1574208900,缺点是精度只能到秒;
  • 以毫秒为单位的整数:1574208900123,最后3位表示毫秒数;
  • 以秒为单位的浮点数:1574208900.123,小数点后面表示零点几秒。

它们之间转换非常简单。而在Java程序中,时间戳通常是用long表示的毫秒数,即:

1
long t = 1574208900123L;

转换成北京时间就是2019-11-20T8:15:00.123。要获取当前时间戳,可以使用System.currentTimeMillis(),这是Java程序获取时间戳最常用的方法。

标准库API

我们再来看一下Java标准库提供的API。Java标准库有两套处理日期和时间的API:

  • 一套定义在java.util这个包里面,主要包括DateCalendarTimeZone这几个类;
  • 一套新的API是在Java 8引入的,定义在java.time这个包里面,主要包括LocalDateTimeZonedDateTimeZoneId等。

为什么会有新旧两套API呢?因为历史遗留原因,旧的API存在很多问题,所以引入了新的API。

那么我们能不能跳过旧的API直接用新的API呢?如果涉及到遗留代码就不行,因为很多遗留代码仍然使用旧的API,所以目前仍然需要对旧的API有一定了解,很多时候还需要在新旧两种对象之间进行转换。

本节我们快速讲解旧API的常用类型和方法。

Date

java.util.Date是用于表示一个日期和时间的对象,注意与java.sql.Date区分,后者用在数据库中。如果观察Date的源码,可以发现它实际上存储了一个long类型的以毫秒表示的时间戳:

1
2
3
4
5
6
public class Date implements Serializable, Cloneable, Comparable<Date> {

private transient long fastTime;

...
}

我们来看Date的基本用法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import java.util.*;

public class Main {
public static void main(String[] args) {
// 获取当前时间:
Date date = new Date();
System.out.println(date.getYear() + 1900); // 必须加上1900
System.out.println(date.getMonth() + 1); // 0~11,必须加上1
System.out.println(date.getDate()); // 1~31,不能加1
// 转换为String:
System.out.println(date.toString());
// 转换为GMT时区:
System.out.println(date.toGMTString());
// 转换为本地时区:
System.out.println(date.toLocaleString());
}
}

注意getYear()返回的年份必须加上1900getMonth()返回的月份是0~11分别表示1~12月,所以要加1,而getDate()返回的日期范围是1~31,又不能加1。

打印本地时区表示的日期和时间时,不同的计算机可能会有不同的结果。如果我们想要针对用户的偏好精确地控制日期和时间的格式,就可以使用SimpleDateFormat对一个Date进行转换。它用预定义的字符串表示格式化:

  • yyyy:年
  • MM:月
  • dd: 日
  • HH: 小时
  • mm: 分钟
  • ss: 秒

我们来看如何以自定义的格式输出:

1
2
3
4
5
6
7
8
9
10
11
import java.text.*;
import java.util.*;

public class Main {
public static void main(String[] args) {
// 获取当前时间:
Date date = new Date();
var sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
System.out.println(sdf.format(date));
}
}

Java的格式化预定义了许多不同的格式,我们以MMME为例:

1
2
3
4
5
6
7
8
9
10
11
import java.text.*;
import java.util.*;

public class Main {
public static void main(String[] args) {
// 获取当前时间:
Date date = new Date();
var sdf = new SimpleDateFormat("E MMM dd, yyyy");
System.out.println(sdf.format(date));
}
}

上述代码在不同的语言环境会打印出类似Sun Sep 15, 2019这样的日期。可以从JDK文档查看详细的格式说明。一般来说,字母越长,输出越长。以M为例,假设当前月份是9月:

  • M:输出9
  • MM:输出09
  • MMM:输出Sep
  • MMMM:输出September

Date对象有几个严重的问题:它不能转换时区,除了toGMTString()可以按GMT+0:00输出外,Date总是以当前计算机系统的默认时区为基础进行输出。此外,我们也很难对日期和时间进行加减,计算两个日期相差多少天,计算某个月第一个星期一的日期等。

Calendar

Calendar可以用于获取并设置年、月、日、时、分、秒,它和Date比,主要多了一个可以做简单的日期和时间运算的功能。

我们来看Calendar的基本用法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import java.util.*;

public class Main {
public static void main(String[] args) {
// 获取当前时间:
Calendar c = Calendar.getInstance();
int y = c.get(Calendar.YEAR);
int m = 1 + c.get(Calendar.MONTH);
int d = c.get(Calendar.DAY_OF_MONTH);
int w = c.get(Calendar.DAY_OF_WEEK);
int hh = c.get(Calendar.HOUR_OF_DAY);
int mm = c.get(Calendar.MINUTE);
int ss = c.get(Calendar.SECOND);
int ms = c.get(Calendar.MILLISECOND);
System.out.println(y + "-" + m + "-" + d + " " + w + " " + hh + ":" + mm + ":" + ss + "." + ms);
}
}

注意到Calendar获取年月日这些信息变成了get(int field),返回的年份不必转换,返回的月份仍然要加1,返回的星期要特别注意,1~7分别表示周日,周一,……,周六。

Calendar只有一种方式获取,即Calendar.getInstance(),而且一获取到就是当前时间。如果我们想给它设置成特定的一个日期和时间,就必须先清除所有字段:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import java.text.*;
import java.util.*;

public class Main {
public static void main(String[] args) {
// 当前时间:
Calendar c = Calendar.getInstance();
// 清除所有:
c.clear();
// 设置2019年:
c.set(Calendar.YEAR, 2019);
// 设置9月:注意8表示9月:
c.set(Calendar.MONTH, 8);
// 设置2日:
c.set(Calendar.DATE, 2);
// 设置时间:
c.set(Calendar.HOUR_OF_DAY, 21);
c.set(Calendar.MINUTE, 22);
c.set(Calendar.SECOND, 23);
System.out.println(new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(c.getTime()));
// 2019-09-02 21:22:23
}
}

利用Calendar.getTime()可以将一个Calendar对象转换成Date对象,然后就可以用SimpleDateFormat进行格式化了。

TimeZone

CalendarDate相比,它提供了时区转换的功能。时区用TimeZone对象表示:

1
2
3
4
5
6
7
8
9
10
11
12
import java.util.*;

public class Main {
public static void main(String[] args) {
TimeZone tzDefault = TimeZone.getDefault(); // 当前时区
TimeZone tzGMT9 = TimeZone.getTimeZone("GMT+09:00"); // GMT+9:00时区
TimeZone tzNY = TimeZone.getTimeZone("America/New_York"); // 纽约时区
System.out.println(tzDefault.getID()); // Asia/Shanghai
System.out.println(tzGMT9.getID()); // GMT+09:00
System.out.println(tzNY.getID()); // America/New_York
}
}

时区的唯一标识是以字符串表示的ID,我们获取指定TimeZone对象也是以这个ID为参数获取,GMT+09:00Asia/Shanghai都是有效的时区ID。要列出系统支持的所有ID,请使用TimeZone.getAvailableIDs()

有了时区,我们就可以对指定时间进行转换。例如,下面的例子演示了如何将北京时间2019-11-20 8:15:00转换为纽约时间:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import java.text.*;
import java.util.*;

public class Main {
public static void main(String[] args) {
// 当前时间:
Calendar c = Calendar.getInstance();
// 清除所有:
c.clear();
// 设置为北京时区:
c.setTimeZone(TimeZone.getTimeZone("Asia/Shanghai"));
// 设置年月日时分秒:
c.set(2019, 10 /* 11月 */, 20, 8, 15, 0);
// 显示时间:
var sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
sdf.setTimeZone(TimeZone.getTimeZone("America/New_York"));
System.out.println(sdf.format(c.getTime()));
// 2019-11-19 19:15:00
}
}

可见,利用Calendar进行时区转换的步骤是:

  1. 清除所有字段;
  2. 设定指定时区;
  3. 设定日期和时间;
  4. 创建SimpleDateFormat并设定目标时区;
  5. 格式化获取的Date对象(注意Date对象无时区信息,时区信息存储在SimpleDateFormat中)。

因此,本质上时区转换只能通过SimpleDateFormat在显示的时候完成。

Calendar也可以对日期和时间进行简单的加减:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import java.text.*;
import java.util.*;

public class Main {
public static void main(String[] args) {
// 当前时间:
Calendar c = Calendar.getInstance();
// 清除所有:
c.clear();
// 设置年月日时分秒:
c.set(2019, 10 /* 11月 */, 20, 8, 15, 0);
// 加5天并减去2小时:
c.add(Calendar.DAY_OF_MONTH, 5);
c.add(Calendar.HOUR_OF_DAY, -2);
// 显示时间:
var sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
Date d = c.getTime();
System.out.println(sdf.format(d));
// 2019-11-25 6:15:00
}
}

小结

计算机表示的时间是以整数表示的时间戳存储的,即Epoch Time,Java使用long型来表示以毫秒为单位的时间戳,通过System.currentTimeMillis()获取当前时间戳。

Java有两套日期和时间的API:

  • 旧的Date、Calendar和TimeZone;
  • 新的LocalDateTime、ZonedDateTime、ZoneId等。

分别位于java.utiljava.time包中。

从Java 8开始,java.time包提供了新的日期和时间API,主要涉及的类型有:

  • 本地日期和时间:LocalDateTimeLocalDateLocalTime
  • 带时区的日期和时间:ZonedDateTime
  • 时刻:Instant
  • 时区:ZoneIdZoneOffset
  • 时间间隔:Duration

以及一套新的用于取代SimpleDateFormat的格式化类型DateTimeFormatter

和旧的API相比,新API严格区分了时刻、本地日期、本地时间和带时区的日期时间,并且,对日期和时间进行运算更加方便。

此外,新API修正了旧API不合理的常量设计:

  • Month的范围用1~12表示1月到12月;
  • Week的范围用1~7表示周一到周日。

最后,新API的类型几乎全部是不变类型(和String类似),可以放心使用不必担心被修改。

LocalDateTime

我们首先来看最常用的LocalDateTime,它表示一个本地日期和时间:

1
2
3
4
5
6
7
8
9
10
11
12
import java.time.*;

public class Main {
public static void main(String[] args) {
LocalDate d = LocalDate.now(); // 当前日期
LocalTime t = LocalTime.now(); // 当前时间
LocalDateTime dt = LocalDateTime.now(); // 当前日期和时间
System.out.println(d); // 严格按照ISO 8601格式打印
System.out.println(t); // 严格按照ISO 8601格式打印
System.out.println(dt); // 严格按照ISO 8601格式打印
}
}

本地日期和时间通过now()获取到的总是以当前默认时区返回的,和旧API不同,LocalDateTimeLocalDateLocalTime默认严格按照ISO 8601规定的日期和时间格式进行打印。

上述代码其实有一个小问题,在获取3个类型的时候,由于执行一行代码总会消耗一点时间,因此,3个类型的日期和时间很可能对不上(时间的毫秒数基本上不同)。为了保证获取到同一时刻的日期和时间,可以改写如下:

1
2
3
LocalDateTime dt = LocalDateTime.now(); // 当前日期和时间
LocalDate d = dt.toLocalDate(); // 转换到当前日期
LocalTime t = dt.toLocalTime(); // 转换到当前时间

反过来,通过指定的日期和时间创建LocalDateTime可以通过of()方法:

1
2
3
4
5
// 指定日期和时间:
LocalDate d2 = LocalDate.of(2019, 11, 30); // 2019-11-30, 注意11=11月
LocalTime t2 = LocalTime.of(15, 16, 17); // 15:16:17
LocalDateTime dt2 = LocalDateTime.of(2019, 11, 30, 15, 16, 17);
LocalDateTime dt3 = LocalDateTime.of(d2, t2);

因为严格按照ISO 8601的格式,因此,将字符串转换为LocalDateTime就可以传入标准格式:

1
2
3
LocalDateTime dt = LocalDateTime.parse("2019-11-19T15:16:17");
LocalDate d = LocalDate.parse("2019-11-19");
LocalTime t = LocalTime.parse("15:16:17");

注意ISO 8601规定的日期和时间分隔符是T。标准格式如下:

  • 日期:yyyy-MM-dd
  • 时间:HH:mm:ss
  • 带毫秒的时间:HH:mm:ss.SSS
  • 日期和时间:yyyy-MM-dd’T’HH:mm:ss
  • 带毫秒的日期和时间:yyyy-MM-dd’T’HH:mm:ss.SSS

DateTimeFormatter

如果要自定义输出的格式,或者要把一个非ISO 8601格式的字符串解析成LocalDateTime,可以使用新的DateTimeFormatter

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import java.time.*;
import java.time.format.*;

public class Main {
public static void main(String[] args) {
// 自定义格式化:
DateTimeFormatter dtf = DateTimeFormatter.ofPattern("yyyy/MM/dd HH:mm:ss");
System.out.println(dtf.format(LocalDateTime.now()));

// 用自定义格式解析:
LocalDateTime dt2 = LocalDateTime.parse("2019/11/30 15:16:17", dtf);
System.out.println(dt2);
}
}

LocalDateTime提供了对日期和时间进行加减的非常简单的链式调用:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import java.time.*;

public class Main {
public static void main(String[] args) {
LocalDateTime dt = LocalDateTime.of(2019, 10, 26, 20, 30, 59);
System.out.println(dt);
// 加5天减3小时:
LocalDateTime dt2 = dt.plusDays(5).minusHours(3);
System.out.println(dt2); // 2019-10-31T17:30:59
// 减1月:
LocalDateTime dt3 = dt2.minusMonths(1);
System.out.println(dt3); // 2019-09-30T17:30:59
}
}

注意到月份加减会自动调整日期,例如从2019-10-31减去1个月得到的结果是2019-09-30,因为9月没有31日。

对日期和时间进行调整则使用withXxx()方法,例如:withHour(15)会把10:11:12变为15:11:12

  • 调整年:withYear()
  • 调整月:withMonth()
  • 调整日:withDayOfMonth()
  • 调整时:withHour()
  • 调整分:withMinute()
  • 调整秒:withSecond()

示例代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import java.time.*;

public class Main {
public static void main(String[] args) {
LocalDateTime dt = LocalDateTime.of(2019, 10, 26, 20, 30, 59);
System.out.println(dt);
// 日期变为31日:
LocalDateTime dt2 = dt.withDayOfMonth(31);
System.out.println(dt2); // 2019-10-31T20:30:59
// 月份变为9:
LocalDateTime dt3 = dt2.withMonth(9);
System.out.println(dt3); // 2019-09-30T20:30:59
}
}

同样注意到调整月份时,会相应地调整日期,即把2019-10-31的月份调整为9时,日期也自动变为30

实际上,LocalDateTime还有一个通用的with()方法允许我们做更复杂的运算。例如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import java.time.*;
import java.time.temporal.*;

public class Main {
public static void main(String[] args) {
// 本月第一天0:00时刻:
LocalDateTime firstDay = LocalDate.now().withDayOfMonth(1).atStartOfDay();
System.out.println(firstDay);

// 本月最后1天:
LocalDate lastDay = LocalDate.now().with(TemporalAdjusters.lastDayOfMonth());
System.out.println(lastDay);

// 下月第1天:
LocalDate nextMonthFirstDay = LocalDate.now().with(TemporalAdjusters.firstDayOfNextMonth());
System.out.println(nextMonthFirstDay);

// 本月第1个周一:
LocalDate firstWeekday = LocalDate.now().with(TemporalAdjusters.firstInMonth(DayOfWeek.MONDAY));
System.out.println(firstWeekday);
}
}

对于计算某个月第1个周日这样的问题,新的API可以轻松完成。

要判断两个LocalDateTime的先后,可以使用isBefore()isAfter()方法,对于LocalDateLocalTime类似:

1
2
3
4
5
6
7
8
9
10
11
import java.time.*;

public class Main {
public static void main(String[] args) {
LocalDateTime now = LocalDateTime.now();
LocalDateTime target = LocalDateTime.of(2019, 11, 19, 8, 15, 0);
System.out.println(now.isBefore(target));
System.out.println(LocalDate.now().isBefore(LocalDate.of(2019, 11, 19)));
System.out.println(LocalTime.now().isAfter(LocalTime.parse("08:15:00")));
}
}

注意到LocalDateTime无法与时间戳进行转换,因为LocalDateTime没有时区,无法确定某一时刻。后面我们要介绍的ZonedDateTime相当于LocalDateTime加时区的组合,它具有时区,可以与long表示的时间戳进行转换。

Duration和Period

Duration表示两个时刻之间的时间间隔。另一个类似的Period表示两个日期之间的天数:

1
2
3
4
5
6
7
8
9
10
11
12
13
import java.time.*;

public class Main {
public static void main(String[] args) {
LocalDateTime start = LocalDateTime.of(2019, 11, 19, 8, 15, 0);
LocalDateTime end = LocalDateTime.of(2020, 1, 9, 19, 25, 30);
Duration d = Duration.between(start, end);
System.out.println(d); // PT1235H10M30S

Period p = LocalDate.of(2019, 11, 19).until(LocalDate.of(2020, 1, 9));
System.out.println(p); // P1M21D
}
}

注意到两个LocalDateTime之间的差值使用Duration表示,类似PT1235H10M30S,表示1235小时10分钟30秒。而两个LocalDate之间的差值用Period表示,类似P1M21D,表示1个月21天。

DurationPeriod的表示方法也符合ISO 8601的格式,它以P...T...的形式表示,P...T之间表示日期间隔,T后面表示时间间隔。如果是PT...的格式表示仅有时间间隔。利用ofXxx()或者parse()方法也可以直接创建Duration

1
2
Duration d1 = Duration.ofHours(10); // 10 hours
Duration d2 = Duration.parse("P1DT2H3M"); // 1 day, 2 hours, 3 minutes

有的童鞋可能发现Java 8引入的java.timeAPI。怎么和一个开源的Joda Time很像?难道JDK也开始抄袭开源了?其实正是因为开源的Joda Time设计很好,应用广泛,所以JDK团队邀请Joda Time的作者Stephen Colebourne共同设计了java.timeAPI。

小结

Java 8引入了新的日期和时间API,它们是不变类,默认按ISO 8601标准格式化和解析;

使用LocalDateTime可以非常方便地对日期和时间进行加减,或者调整日期和时间,它总是返回新对象;

使用isBefore()isAfter()可以判断日期和时间的先后;

使用DurationPeriod可以表示两个日期和时间的“区间间隔”。

LocalDateTime总是表示本地日期和时间,要表示一个带时区的日期和时间,我们就需要ZonedDateTime

可以简单地把ZonedDateTime理解成LocalDateTimeZoneIdZoneIdjava.time引入的新的时区类,注意和旧的java.util.TimeZone区别。

要创建一个ZonedDateTime对象,有以下几种方法,一种是通过now()方法返回当前时间:

1
2
3
4
5
6
7
8
9
10
import java.time.*;

public class Main {
public static void main(String[] args) {
ZonedDateTime zbj = ZonedDateTime.now(); // 默认时区
ZonedDateTime zny = ZonedDateTime.now(ZoneId.of("America/New_York")); // 用指定时区获取当前时间
System.out.println(zbj);
System.out.println(zny);
}
}

观察打印的两个ZonedDateTime,发现它们时区不同,但表示的时间都是同一时刻(毫秒数不同是执行语句时的时间差):

1
2
2019-09-15T20:58:18.786182+08:00[Asia/Shanghai]
2019-09-15T08:58:18.788860-04:00[America/New_York]

另一种方式是通过给一个LocalDateTime附加一个ZoneId,就可以变成ZonedDateTime

1
2
3
4
5
6
7
8
9
10
11
import java.time.*;

public class Main {
public static void main(String[] args) {
LocalDateTime ldt = LocalDateTime.of(2019, 9, 15, 15, 16, 17);
ZonedDateTime zbj = ldt.atZone(ZoneId.systemDefault());
ZonedDateTime zny = ldt.atZone(ZoneId.of("America/New_York"));
System.out.println(zbj);
System.out.println(zny);
}
}

以这种方式创建的ZonedDateTime,它的日期和时间与LocalDateTime相同,但附加的时区不同,因此是两个不同的时刻:

1
2
2019-09-15T15:16:17+08:00[Asia/Shanghai]
2019-09-15T15:16:17-04:00[America/New_York]

时区转换

要转换时区,首先我们需要有一个ZonedDateTime对象,然后,通过withZoneSameInstant()将关联时区转换到另一个时区,转换后日期和时间都会相应调整。

下面的代码演示了如何将北京时间转换为纽约时间:

1
2
3
4
5
6
7
8
9
10
11
12
import java.time.*;

public class Main {
public static void main(String[] args) {
// 以中国时区获取当前时间:
ZonedDateTime zbj = ZonedDateTime.now(ZoneId.of("Asia/Shanghai"));
// 转换为纽约时间:
ZonedDateTime zny = zbj.withZoneSameInstant(ZoneId.of("America/New_York"));
System.out.println(zbj);
System.out.println(zny);
}
}

要特别注意,时区转换的时候,由于夏令时的存在,不同的日期转换的结果很可能是不同的。这是北京时间9月15日的转换结果:

1
2
2019-09-15T21:05:50.187697+08:00[Asia/Shanghai]
2019-09-15T09:05:50.187697-04:00[America/New_York]

这是北京时间11月15日的转换结果:

1
2
2019-11-15T21:05:50.187697+08:00[Asia/Shanghai]
2019-11-15T08:05:50.187697-05:00[America/New_York]

两次转换后的纽约时间有1小时的夏令时时差。

注意

涉及到时区时,千万不要自己计算时差,否则难以正确处理夏令时。

有了ZonedDateTime,将其转换为本地时间就非常简单:

1
2
ZonedDateTime zdt = ...
LocalDateTime ldt = zdt.toLocalDateTime();

转换为LocalDateTime时,直接丢弃了时区信息。

练习

某航线从北京飞到纽约需要13小时20分钟,请根据北京起飞日期和时间计算到达纽约的当地日期和时间。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import java.time.*;

public class Main {
public static void main(String[] args) {
LocalDateTime departureAtBeijing = LocalDateTime.of(2019, 9, 15, 13, 0, 0);
int hours = 13;
int minutes = 20;
LocalDateTime arrivalAtNewYork = calculateArrivalAtNY(departureAtBeijing, hours, minutes);
System.out.println(departureAtBeijing + " -> " + arrivalAtNewYork);
// test:
if (!LocalDateTime.of(2019, 10, 15, 14, 20, 0)
.equals(calculateArrivalAtNY(LocalDateTime.of(2019, 10, 15, 13, 0, 0), 13, 20))) {
System.err.println("测试失败!");
} else if (!LocalDateTime.of(2019, 11, 15, 13, 20, 0)
.equals(calculateArrivalAtNY(LocalDateTime.of(2019, 11, 15, 13, 0, 0), 13, 20))) {
System.err.println("测试失败!");
}
}

static LocalDateTime calculateArrivalAtNY(LocalDateTime bj, int h, int m) {
return bj;
}
}

提示:ZonedDateTime仍然提供了plusDays()等加减操作。

下载练习

小结

ZonedDateTime是带时区的日期和时间,可用于时区转换;

ZonedDateTimeLocalDateTime可以相互转换。

DateTimeFormatter

使用旧的Date对象时,我们用SimpleDateFormat进行格式化显示。使用新的LocalDateTimeZonedDateTime时,我们要进行格式化显示,就要使用DateTimeFormatter

SimpleDateFormat不同的是,DateTimeFormatter不但是不变对象,它还是线程安全的。线程的概念我们会在后面涉及到。现在我们只需要记住:因为SimpleDateFormat不是线程安全的,使用的时候,只能在方法内部创建新的局部变量。而DateTimeFormatter可以只创建一个实例,到处引用。

创建DateTimeFormatter时,我们仍然通过传入格式化字符串实现:

1
DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm");

格式化字符串的使用方式与SimpleDateFormat完全一致。

另一种创建DateTimeFormatter的方法是,传入格式化字符串时,同时指定Locale

1
DateTimeFormatter formatter = DateTimeFormatter.ofPattern("E, yyyy-MMMM-dd HH:mm", Locale.US);

这种方式可以按照Locale默认习惯格式化。我们来看实际效果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import java.time.*;
import java.time.format.*;
import java.util.Locale;

public class Main {
public static void main(String[] args) {
ZonedDateTime zdt = ZonedDateTime.now();
var formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm ZZZZ");
System.out.println(formatter.format(zdt));

var zhFormatter = DateTimeFormatter.ofPattern("yyyy MMM dd EE HH:mm", Locale.CHINA);
System.out.println(zhFormatter.format(zdt));

var usFormatter = DateTimeFormatter.ofPattern("E, MMMM/dd/yyyy HH:mm", Locale.US);
System.out.println(usFormatter.format(zdt));
}
}

在格式化字符串中,如果需要输出固定字符,可以用'xxx'表示。

运行上述代码,分别以默认方式、中国地区和美国地区对当前时间进行显示,结果如下:

1
2
3
2019-09-15T23:16 GMT+08:00
2019 9月 15 周日 23:16
Sun, September/15/2019 23:16

当我们直接调用System.out.println()对一个ZonedDateTime或者LocalDateTime实例进行打印的时候,实际上,调用的是它们的toString()方法,默认的toString()方法显示的字符串就是按照ISO 8601格式显示的,我们可以通过DateTimeFormatter预定义的几个静态变量来引用:

1
2
3
var ldt = LocalDateTime.now();
System.out.println(DateTimeFormatter.ISO_DATE.format(ldt));
System.out.println(DateTimeFormatter.ISO_DATE_TIME.format(ldt));

得到的输出和toString()类似:

1
2
2019-09-15
2019-09-15T23:16:51.56217

小结

ZonedDateTimeLocalDateTime进行格式化,需要使用DateTimeFormatter类;

DateTimeFormatter可以通过格式化字符串和Locale对日期和时间进行定制输出。



Instant

我们已经讲过,计算机存储的当前时间,本质上只是一个不断递增的整数。Java提供的System.currentTimeMillis()返回的就是以毫秒表示的当前时间戳。

这个当前时间戳在java.time中以Instant类型表示,我们用Instant.now()获取当前时间戳,效果和System.currentTimeMillis()类似:

1
2
3
4
5
6
7
8
9
import java.time.*;

public class Main {
public static void main(String[] args) {
Instant now = Instant.now();
System.out.println(now.getEpochSecond()); // 秒
System.out.println(now.toEpochMilli()); // 毫秒
}
}

打印的结果类似:

1
2
1568568760
1568568760316

实际上,Instant内部只有两个核心字段:

1
2
3
4
public final class Instant implements ... {
private final long seconds;
private final int nanos;
}

一个是以秒为单位的时间戳,一个是更精确的纳秒精度。它和System.currentTimeMillis()返回的long相比,只是多了更高精度的纳秒。

既然Instant就是时间戳,那么,给它附加上一个时区,就可以创建出ZonedDateTime

1
2
3
4
// 以指定时间戳创建Instant:
Instant ins = Instant.ofEpochSecond(1568568760);
ZonedDateTime zdt = ins.atZone(ZoneId.systemDefault());
System.out.println(zdt); // 2019-09-16T01:32:40+08:00[Asia/Shanghai]

可见,对于某一个时间戳,给它关联上指定的ZoneId,就得到了ZonedDateTime,继而可以获得了对应时区的LocalDateTime

所以,LocalDateTimeZoneIdInstantZonedDateTimelong都可以互相转换:

1
2
3
4
5
6
7
8
9
10
11
12
┌─────────────┐
│LocalDateTime│────┐
└─────────────┘ │ ┌─────────────┐
├───▶│ZonedDateTime│
┌─────────────┐ │ └─────────────┘
│ ZoneId │────┘ ▲
└─────────────┘ ┌─────────┴─────────┐
│ │
▼ ▼
┌─────────────┐ ┌─────────────┐
│ Instant │◀───▶│ long │
└─────────────┘ └─────────────┘

转换的时候,只需要留意long类型以毫秒还是秒为单位即可。

小结

Instant表示高精度时间戳,它可以和ZonedDateTime以及long互相转换。



由于Java提供了新旧两套日期和时间的API,除非涉及到遗留代码,否则我们应该坚持使用新的API。

如果需要与遗留代码打交道,如何在新旧API之间互相转换呢?

旧API转新API

如果要把旧式的DateCalendar转换为新API对象,可以通过toInstant()方法转换为Instant对象,再继续转换为ZonedDateTime

1
2
3
4
5
6
7
// Date -> Instant:
Instant ins1 = new Date().toInstant();

// Calendar -> Instant -> ZonedDateTime:
Calendar calendar = Calendar.getInstance();
Instant ins2 = calendar.toInstant();
ZonedDateTime zdt = ins2.atZone(calendar.getTimeZone().toZoneId());

从上面的代码还可以看到,旧的TimeZone提供了一个toZoneId(),可以把自己变成新的ZoneId

新API转旧API

如果要把新的ZonedDateTime转换为旧的API对象,只能借助long型时间戳做一个“中转”:

1
2
3
4
5
6
7
8
9
10
11
12
// ZonedDateTime -> long:
ZonedDateTime zdt = ZonedDateTime.now();
long ts = zdt.toEpochSecond() * 1000;

// long -> Date:
Date date = new Date(ts);

// long -> Calendar:
Calendar calendar = Calendar.getInstance();
calendar.clear();
calendar.setTimeZone(TimeZone.getTimeZone(zdt.getZone().getId()));
calendar.setTimeInMillis(zdt.toEpochSecond() * 1000);

从上面的代码还可以看到,新的ZoneId转换为旧的TimeZone,需要借助ZoneId.getId()返回的String完成。

在数据库中存储日期和时间

除了旧式的java.util.Date,我们还可以找到另一个java.sql.Date,它继承自java.util.Date,但会自动忽略所有时间相关信息。这个奇葩的设计原因要追溯到数据库的日期与时间类型。

在数据库中,也存在几种日期和时间类型:

  • DATETIME:表示日期和时间;
  • DATE:仅表示日期;
  • TIME:仅表示时间;
  • TIMESTAMP:和DATETIME类似,但是数据库会在创建或者更新记录的时候同时修改TIMESTAMP

在使用Java程序操作数据库时,我们需要把数据库类型与Java类型映射起来。下表是数据库类型与Java新旧API的映射关系:

数据库 对应Java类(旧) 对应Java类(新)
DATETIME java.util.Date LocalDateTime
DATE java.sql.Date LocalDate
TIME java.sql.Time LocalTime
TIMESTAMP java.sql.Timestamp LocalDateTime

实际上,在数据库中,我们需要存储的最常用的是时刻(Instant),因为有了时刻信息,就可以根据用户自己选择的时区,显示出正确的本地时间。所以,最好的方法是直接用长整数long表示,在数据库中存储为BIGINT类型。

通过存储一个long型时间戳,我们可以编写一个timestampToString()的方法,非常简单地为不同用户以不同的偏好来显示不同的本地时间:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import java.time.*;
import java.time.format.*;
import java.util.Locale;

public class Main {
public static void main(String[] args) {
long ts = 1574208900000L;
System.out.println(timestampToString(ts, Locale.CHINA, "Asia/Shanghai"));
System.out.println(timestampToString(ts, Locale.US, "America/New_York"));
}

static String timestampToString(long epochMilli, Locale lo, String zoneId) {
Instant ins = Instant.ofEpochMilli(epochMilli);
DateTimeFormatter f = DateTimeFormatter.ofLocalizedDateTime(FormatStyle.MEDIUM, FormatStyle.SHORT);
return f.withLocale(lo).format(ZonedDateTime.ofInstant(ins, ZoneId.of(zoneId)));
}
}

对上述方法进行调用,结果如下:

1
2
2019年11月20日 上午8:15
Nov 19, 2019, 7:15 PM

小结

处理日期和时间时,尽量使用新的java.time包;

在数据库中存储时间戳时,尽量使用long型时间戳,它具有省空间,效率高,不依赖数据库的优点。

留言與分享

JAVA-IO

分類 编程语言, Java

IO是指Input/Output,即输入和输出。以内存为中心:

  • Input指从外部读入数据到内存,例如,把文件从磁盘读取到内存,从网络读取数据到内存等等。
  • Output指把数据从内存输出到外部,例如,把数据从内存写入到文件,把数据从内存输出到网络等等。

为什么要把数据读到内存才能处理这些数据?因为代码是在内存中运行的,数据也必须读到内存,最终的表示方式无非是byte数组,字符串等,都必须存放在内存里。

从Java代码来看,输入实际上就是从外部,例如,硬盘上的某个文件,把内容读到内存,并且以Java提供的某种数据类型表示,例如,byte[]String,这样,后续代码才能处理这些数据。

因为内存有“易失性”的特点,所以必须把处理后的数据以某种方式输出,例如,写入到文件。Output实际上就是把Java表示的数据格式,例如,byte[]String等输出到某个地方。

IO流是一种顺序读写数据的模式,它的特点是单向流动。数据类似自来水一样在水管中流动,所以我们把它称为IO流。

java-io

InputStream / OutputStream

IO流以byte(字节)为最小单位,因此也称为字节流。例如,我们要从磁盘读入一个文件,包含6个字节,就相当于读入了6个字节的数据:

1
2
3
4
5
6
7
8
9
10
11
12
13
╔═══════════╗
║ Memory ║
╚═══════════╝

│0x48
│0x65
│0x6c
│0x6c
│0x6f
│0x21
╔═══════════╗
║ Hard Disk ║
╚═══════════╝

这6个字节是按顺序读入的,所以是输入字节流。

反过来,我们把6个字节从内存写入磁盘文件,就是输出字节流:

1
2
3
4
5
6
7
8
9
10
11
12
13
╔═══════════╗
║ Memory ║
╚═══════════╝
│0x21
│0x6f
│0x6c
│0x6c
│0x65
│0x48

╔═══════════╗
║ Hard Disk ║
╚═══════════╝

在Java中,InputStream代表输入字节流,OuputStream代表输出字节流,这是最基本的两种IO流。

Reader / Writer

如果我们需要读写的是字符,并且字符不全是单字节表示的ASCII字符,那么,按照char来读写显然更方便,这种流称为字符流

Java提供了ReaderWriter表示字符流,字符流传输的最小数据单位是char

例如,我们把char[]数组Hi你好这4个字符用Writer字符流写入文件,并且使用UTF-8编码,得到的最终文件内容是8个字节,英文字符Hi各占一个字节,中文字符你好各占3个字节:

1
2
3
4
0x48
0x69
0xe4bda0
0xe5a5bd

反过来,我们用Reader读取以UTF-8编码的这8个字节,会从Reader中得到Hi你好这4个字符。

因此,ReaderWriter本质上是一个能自动编解码的InputStreamOutputStream

使用Reader,数据源虽然是字节,但我们读入的数据都是char类型的字符,原因是Reader内部把读入的byte做了解码,转换成了char。使用InputStream,我们读入的数据和原始二进制数据一模一样,是byte[]数组,但是我们可以自己把二进制byte[]数组按照某种编码转换为字符串。究竟使用Reader还是InputStream,要取决于具体的使用场景。如果数据源不是文本,就只能使用InputStream,如果数据源是文本,使用Reader更方便一些。WriterOutputStream是类似的。

同步和异步

同步IO是指,读写IO时代码必须等待数据返回后才继续执行后续代码,它的优点是代码编写简单,缺点是CPU执行效率低。

而异步IO是指,读写IO时仅发出请求,然后立刻执行后续代码,它的优点是CPU执行效率高,缺点是代码编写复杂。

Java标准库的包java.io提供了同步IO,而java.nio则是异步IO。上面我们讨论的InputStreamOutputStreamReaderWriter都是同步IO的抽象类,对应的具体实现类,以文件为例,有FileInputStreamFileOutputStreamFileReaderFileWriter

本节我们只讨论Java的同步IO,即输入/输出流的IO模型。

小结

IO流是一种流式的数据输入/输出模型:

  • 二进制数据以byte为最小单位在InputStream/OutputStream中单向流动;
  • 字符数据以char为最小单位在Reader/Writer中单向流动。

Java标准库的java.io包提供了同步IO功能:

  • 字节流接口:InputStream/OutputStream
  • 字符流接口:Reader/Writer

在计算机系统中,文件是非常重要的存储方式。Java的标准库java.io提供了File对象来操作文件和目录。

要构造一个File对象,需要传入文件路径:

1
2
3
4
5
6
7
8
import java.io.*;

public class Main {
public static void main(String[] args) {
File f = new File("C:\\Windows\\notepad.exe");
System.out.println(f);
}
}

构造File对象时,既可以传入绝对路径,也可以传入相对路径。绝对路径是以根目录开头的完整路径,例如:

1
File f = new File("C:\\Windows\\notepad.exe");

注意Windows平台使用\作为路径分隔符,在Java字符串中需要用\\表示一个\。Linux平台使用/作为路径分隔符:

1
File f = new File("/usr/bin/javac");

传入相对路径时,相对路径前面加上当前目录就是绝对路径:

1
2
3
4
// 假设当前目录是C:\Docs
File f1 = new File("sub\\javac"); // 绝对路径是C:\Docs\sub\javac
File f3 = new File(".\\sub\\javac"); // 绝对路径是C:\Docs\sub\javac
File f3 = new File("..\\sub\\javac"); // 绝对路径是C:\sub\javac

可以用.表示当前目录,..表示上级目录。

File对象有3种形式表示的路径,一种是getPath(),返回构造方法传入的路径,一种是getAbsolutePath(),返回绝对路径,一种是getCanonicalPath,它和绝对路径类似,但是返回的是规范路径。

什么是规范路径?我们看以下代码:

1
2
3
4
5
6
7
8
9
10
import java.io.*;

public class Main {
public static void main(String[] args) throws IOException {
File f = new File("..");
System.out.println(f.getPath());
System.out.println(f.getAbsolutePath());
System.out.println(f.getCanonicalPath());
}
}

绝对路径可以表示成C:\Windows\System32\..\notepad.exe,而规范路径就是把...转换成标准的绝对路径后的路径:C:\Windows\notepad.exe

因为Windows和Linux的路径分隔符不同,File对象有一个静态变量用于表示当前平台的系统分隔符:

1
System.out.println(File.separator); // 根据当前平台打印"\"或"/"

文件和目录

File对象既可以表示文件,也可以表示目录。特别要注意的是,构造一个File对象,即使传入的文件或目录不存在,代码也不会出错,因为构造一个File对象,并不会导致任何磁盘操作。只有当我们调用File对象的某些方法的时候,才真正进行磁盘操作。

例如,调用isFile(),判断该File对象是否是一个已存在的文件,调用isDirectory(),判断该File对象是否是一个已存在的目录:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import java.io.*;

public class Main {
public static void main(String[] args) throws IOException {
File f1 = new File("C:\\Windows");
File f2 = new File("C:\\Windows\\notepad.exe");
File f3 = new File("C:\\Windows\\nothing");
System.out.println(f1.isFile());
System.out.println(f1.isDirectory());
System.out.println(f2.isFile());
System.out.println(f2.isDirectory());
System.out.println(f3.isFile());
System.out.println(f3.isDirectory());
}
}

File对象获取到一个文件时,还可以进一步判断文件的权限和大小:

  • boolean canRead():是否可读;
  • boolean canWrite():是否可写;
  • boolean canExecute():是否可执行;
  • long length():文件字节大小。

对目录而言,是否可执行表示能否列出它包含的文件和子目录。

创建和删除文件

当File对象表示一个文件时,可以通过createNewFile()创建一个新文件,用delete()删除该文件:

1
2
3
4
5
6
7
8
File file = new File("/path/to/file");
if (file.createNewFile()) {
// 文件创建成功:
// TODO:
if (file.delete()) {
// 删除文件成功:
}
}

有些时候,程序需要读写一些临时文件,File对象提供了createTempFile()来创建一个临时文件,以及deleteOnExit()在JVM退出时自动删除该文件。

1
2
3
4
5
6
7
8
9
10
import java.io.*;

public class Main {
public static void main(String[] args) throws IOException {
File f = File.createTempFile("tmp-", ".txt"); // 提供临时文件的前缀和后缀
f.deleteOnExit(); // JVM退出时自动删除
System.out.println(f.isFile());
System.out.println(f.getAbsolutePath());
}
}

遍历文件和目录

当File对象表示一个目录时,可以使用list()listFiles()列出目录下的文件和子目录名。listFiles()提供了一系列重载方法,可以过滤不想要的文件和目录:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import java.io.*;

public class Main {
public static void main(String[] args) throws IOException {
File f = new File("C:\\Windows");
File[] fs1 = f.listFiles(); // 列出所有文件和子目录
printFiles(fs1);
File[] fs2 = f.listFiles(new FilenameFilter() { // 仅列出.exe文件
public boolean accept(File dir, String name) {
return name.endsWith(".exe"); // 返回true表示接受该文件
}
});
printFiles(fs2);
}

static void printFiles(File[] files) {
System.out.println("==========");
if (files != null) {
for (File f : files) {
System.out.println(f);
}
}
System.out.println("==========");
}
}

和文件操作类似,File对象如果表示一个目录,可以通过以下方法创建和删除目录:

  • boolean mkdir():创建当前File对象表示的目录;
  • boolean mkdirs():创建当前File对象表示的目录,并在必要时将不存在的父目录也创建出来;
  • boolean delete():删除当前File对象表示的目录,当前目录必须为空才能删除成功。

Path

Java标准库还提供了一个Path对象,它位于java.nio.file包。Path对象和File对象类似,但操作更加简单:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import java.io.*;
import java.nio.file.*;

public class Main {
public static void main(String[] args) throws IOException {
Path p1 = Paths.get(".", "project", "study"); // 构造一个Path对象
System.out.println(p1);
Path p2 = p1.toAbsolutePath(); // 转换为绝对路径
System.out.println(p2);
Path p3 = p2.normalize(); // 转换为规范路径
System.out.println(p3);
File f = p3.toFile(); // 转换为File对象
System.out.println(f);
for (Path p : Paths.get("..").toAbsolutePath()) { // 可以直接遍历Path
System.out.println(" " + p);
}
}
}

如果需要对目录进行复杂的拼接、遍历等操作,使用Path对象更方便。

练习

请利用File对象列出指定目录下的所有子目录和文件,并按层次打印。

例如,输出:

1
2
3
4
5
6
7
8
Documents/
word/
1.docx
2.docx
work/
abc.doc
ppt/
other/

如果不指定参数,则使用当前目录,如果指定参数,则使用指定目录。

下载练习

小结

Java标准库的java.io.File对象表示一个文件或者目录:

  • 创建File对象本身不涉及IO操作;
  • 可以获取路径/绝对路径/规范路径:getPath()/getAbsolutePath()/getCanonicalPath()
  • 可以获取目录的文件和子目录:list()/listFiles()
  • 可以创建或删除文件和目录。

InputStream就是Java标准库提供的最基本的输入流。它位于java.io这个包里。java.io包提供了所有同步IO的功能。

要特别注意的一点是,InputStream并不是一个接口,而是一个抽象类,它是所有输入流的超类。这个抽象类定义的一个最重要的方法就是int read(),签名如下:

1
public abstract int read() throws IOException;

这个方法会读取输入流的下一个字节,并返回字节表示的int值(0~255)。如果已读到末尾,返回-1表示不能继续读取了。

FileInputStreamInputStream的一个子类。顾名思义,FileInputStream就是从文件流中读取数据。下面的代码演示了如何完整地读取一个FileInputStream的所有字节:

1
2
3
4
5
6
7
8
9
10
11
12
public void readFile() throws IOException {
// 创建一个FileInputStream对象:
InputStream input = new FileInputStream("src/readme.txt");
for (;;) {
int n = input.read(); // 反复调用read()方法,直到返回-1
if (n == -1) {
break;
}
System.out.println(n); // 打印byte的值
}
input.close(); // 关闭流
}

在计算机中,类似文件、网络端口这些资源,都是由操作系统统一管理的。应用程序在运行的过程中,如果打开了一个文件进行读写,完成后要及时地关闭,以便让操作系统把资源释放掉,否则,应用程序占用的资源会越来越多,不但白白占用内存,还会影响其他应用程序的运行。

InputStreamOutputStream都是通过close()方法来关闭流。关闭流就会释放对应的底层资源。

我们还要注意到在读取或写入IO流的过程中,可能会发生错误,例如,文件不存在导致无法读取,没有写权限导致写入失败,等等,这些底层错误由Java虚拟机自动封装成IOException异常并抛出。因此,所有与IO操作相关的代码都必须正确处理IOException

仔细观察上面的代码,会发现一个潜在的问题:如果读取过程中发生了IO错误,InputStream就没法正确地关闭,资源也就没法及时释放。

因此,我们需要用try ... finally来保证InputStream在无论是否发生IO错误的时候都能够正确地关闭:

1
2
3
4
5
6
7
8
9
10
11
12
public void readFile() throws IOException {
InputStream input = null;
try {
input = new FileInputStream("src/readme.txt");
int n;
while ((n = input.read()) != -1) { // 利用while同时读取并判断
System.out.println(n);
}
} finally {
if (input != null) { input.close(); }
}
}

try ... finally来编写上述代码会感觉比较复杂,更好的写法是利用Java 7引入的新的try(resource)的语法,只需要编写try语句,让编译器自动为我们关闭资源。推荐的写法如下:

1
2
3
4
5
6
7
8
public void readFile() throws IOException {
try (InputStream input = new FileInputStream("src/readme.txt")) {
int n;
while ((n = input.read()) != -1) {
System.out.println(n);
}
} // 编译器在此自动为我们写入finally并调用close()
}

实际上,编译器并不会特别地为InputStream加上自动关闭。编译器只看try(resource = ...)中的对象是否实现了java.lang.AutoCloseable接口,如果实现了,就自动加上finally语句并调用close()方法。InputStreamOutputStream都实现了这个接口,因此,都可以用在try(resource)中。

缓冲

在读取流的时候,一次读取一个字节并不是最高效的方法。很多流支持一次性读取多个字节到缓冲区,对于文件和网络流来说,利用缓冲区一次性读取多个字节效率往往要高很多。InputStream提供了两个重载方法来支持读取多个字节:

  • int read(byte[] b):读取若干字节并填充到byte[]数组,返回读取的字节数
  • int read(byte[] b, int off, int len):指定byte[]数组的偏移量和最大填充数

利用上述方法一次读取多个字节时,需要先定义一个byte[]数组作为缓冲区,read()方法会尽可能多地读取字节到缓冲区, 但不会超过缓冲区的大小。read()方法的返回值不再是字节的int值,而是返回实际读取了多少个字节。如果返回-1,表示没有更多的数据了。

利用缓冲区一次读取多个字节的代码如下:

1
2
3
4
5
6
7
8
9
10
public void readFile() throws IOException {
try (InputStream input = new FileInputStream("src/readme.txt")) {
// 定义1000个字节大小的缓冲区:
byte[] buffer = new byte[1000];
int n;
while ((n = input.read(buffer)) != -1) { // 读取到缓冲区
System.out.println("read " + n + " bytes.");
}
}
}

阻塞

在调用InputStreamread()方法读取数据时,我们说read()方法是阻塞(Blocking)的。它的意思是,对于下面的代码:

1
2
3
int n;
n = input.read(); // 必须等待read()方法返回才能执行下一行代码
int m = n;

执行到第二行代码时,必须等read()方法返回后才能继续。因为读取IO流相比执行普通代码,速度会慢很多,因此,无法确定read()方法调用到底要花费多长时间。

InputStream实现类

FileInputStream可以从文件获取输入流,这是InputStream常用的一个实现类。此外,ByteArrayInputStream可以在内存中模拟一个InputStream

1
2
3
4
5
6
7
8
9
10
11
12
13
import java.io.*;

public class Main {
public static void main(String[] args) throws IOException {
byte[] data = { 72, 101, 108, 108, 111, 33 };
try (InputStream input = new ByteArrayInputStream(data)) {
int n;
while ((n = input.read()) != -1) {
System.out.println((char)n);
}
}
}
}

ByteArrayInputStream实际上是把一个byte[]数组在内存中变成一个InputStream,虽然实际应用不多,但测试的时候,可以用它来构造一个InputStream

举个例子:我们想从文件中读取所有字节,并转换成char然后拼成一个字符串,可以这么写:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
public class Main {
public static void main(String[] args) throws IOException {
String s;
try (InputStream input = new FileInputStream("C:\\test\\README.txt")) {
int n;
StringBuilder sb = new StringBuilder();
while ((n = input.read()) != -1) {
sb.append((char) n);
}
s = sb.toString();
}
System.out.println(s);
}
}

要测试上面的程序,就真的需要在本地硬盘上放一个真实的文本文件。如果我们把代码稍微改造一下,提取一个readAsString()的方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
public class Main {
public static void main(String[] args) throws IOException {
String s;
try (InputStream input = new FileInputStream("C:\\test\\README.txt")) {
s = readAsString(input);
}
System.out.println(s);
}

public static String readAsString(InputStream input) throws IOException {
int n;
StringBuilder sb = new StringBuilder();
while ((n = input.read()) != -1) {
sb.append((char) n);
}
return sb.toString();
}
}

对这个String readAsString(InputStream input)方法进行测试就相当简单,因为不一定要传入一个真的FileInputStream

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import java.io.*;

public class Main {
public static void main(String[] args) throws IOException {
byte[] data = { 72, 101, 108, 108, 111, 33 };
try (InputStream input = new ByteArrayInputStream(data)) {
String s = readAsString(input);
System.out.println(s);
}
}

public static String readAsString(InputStream input) throws IOException {
int n;
StringBuilder sb = new StringBuilder();
while ((n = input.read()) != -1) {
sb.append((char) n);
}
return sb.toString();
}
}

这就是面向抽象编程原则的应用:接受InputStream抽象类型,而不是具体的FileInputStream类型,从而使得代码可以处理InputStream的任意实现类。

小结

Java标准库的java.io.InputStream定义了所有输入流的超类:

  • FileInputStream实现了文件流输入;
  • ByteArrayInputStream在内存中模拟一个字节流输入。

总是使用try(resource)来保证InputStream正确关闭。

InputStream相反,OutputStream是Java标准库提供的最基本的输出流。

InputStream类似,OutputStream也是抽象类,它是所有输出流的超类。这个抽象类定义的一个最重要的方法就是void write(int b),签名如下:

1
public abstract void write(int b) throws IOException;

这个方法会写入一个字节到输出流。要注意的是,虽然传入的是int参数,但只会写入一个字节,即只写入int最低8位表示字节的部分(相当于b & 0xff)。

InputStream类似,OutputStream也提供了close()方法关闭输出流,以便释放系统资源。要特别注意:OutputStream还提供了一个flush()方法,它的目的是将缓冲区的内容真正输出到目的地。

为什么要有flush()?因为向磁盘、网络写入数据的时候,出于效率的考虑,操作系统并不是输出一个字节就立刻写入到文件或者发送到网络,而是把输出的字节先放到内存的一个缓冲区里(本质上就是一个byte[]数组),等到缓冲区写满了,再一次性写入文件或者网络。对于很多IO设备来说,一次写一个字节和一次写1000个字节,花费的时间几乎是完全一样的,所以OutputStream有个flush()方法,能强制把缓冲区内容输出。

通常情况下,我们不需要调用这个flush()方法,因为缓冲区写满了OutputStream会自动调用它,并且,在调用close()方法关闭OutputStream之前,也会自动调用flush()方法。

但是,在某些情况下,我们必须手动调用flush()方法。举个栗子:

小明正在开发一款在线聊天软件,当用户输入一句话后,就通过OutputStreamwrite()方法写入网络流。小明测试的时候发现,发送方输入后,接收方根本收不到任何信息,怎么回事?

原因就在于写入网络流是先写入内存缓冲区,等缓冲区满了才会一次性发送到网络。如果缓冲区大小是4K,则发送方要敲几千个字符后,操作系统才会把缓冲区的内容发送出去,这个时候,接收方会一次性收到大量消息。

解决办法就是每输入一句话后,立刻调用flush(),不管当前缓冲区是否已满,强迫操作系统把缓冲区的内容立刻发送出去。

实际上,InputStream也有缓冲区。例如,从FileInputStream读取一个字节时,操作系统往往会一次性读取若干字节到缓冲区,并维护一个指针指向未读的缓冲区。然后,每次我们调用int read()读取下一个字节时,可以直接返回缓冲区的下一个字节,避免每次读一个字节都导致IO操作。当缓冲区全部读完后继续调用read(),则会触发操作系统的下一次读取并再次填满缓冲区。

FileOutputStream

我们以FileOutputStream为例,演示如何将若干个字节写入文件流:

1
2
3
4
5
6
7
8
9
public void writeFile() throws IOException {
OutputStream output = new FileOutputStream("out/readme.txt");
output.write(72); // H
output.write(101); // e
output.write(108); // l
output.write(108); // l
output.write(111); // o
output.close();
}

每次写入一个字节非常麻烦,更常见的方法是一次性写入若干字节。这时,可以用OutputStream提供的重载方法void write(byte[])来实现:

1
2
3
4
5
public void writeFile() throws IOException {
OutputStream output = new FileOutputStream("out/readme.txt");
output.write("Hello".getBytes("UTF-8")); // Hello
output.close();
}

InputStream一样,上述代码没有考虑到在发生异常的情况下如何正确地关闭资源。写入过程也会经常发生IO错误,例如,磁盘已满,无权限写入等等。我们需要用try(resource)来保证OutputStream在无论是否发生IO错误的时候都能够正确地关闭:

1
2
3
4
5
public void writeFile() throws IOException {
try (OutputStream output = new FileOutputStream("out/readme.txt")) {
output.write("Hello".getBytes("UTF-8")); // Hello
} // 编译器在此自动为我们写入finally并调用close()
}

阻塞

InputStream一样,OutputStreamwrite()方法也是阻塞的。

OutputStream实现类

FileOutputStream可以从文件获取输出流,这是OutputStream常用的一个实现类。此外,ByteArrayOutputStream可以在内存中模拟一个OutputStream

1
2
3
4
5
6
7
8
9
10
11
12
13
import java.io.*;

public class Main {
public static void main(String[] args) throws IOException {
byte[] data;
try (ByteArrayOutputStream output = new ByteArrayOutputStream()) {
output.write("Hello ".getBytes("UTF-8"));
output.write("world!".getBytes("UTF-8"));
data = output.toByteArray();
}
System.out.println(new String(data, "UTF-8"));
}
}

ByteArrayOutputStream实际上是把一个byte[]数组在内存中变成一个OutputStream,虽然实际应用不多,但测试的时候,可以用它来构造一个OutputStream

同时操作多个AutoCloseable资源时,在try(resource) { ... }语句中可以同时写出多个资源,用;隔开。例如,同时读写两个文件:

1
2
3
4
5
6
// 读取input.txt,写入output.txt:
try (InputStream input = new FileInputStream("input.txt");
OutputStream output = new FileOutputStream("output.txt"))
{
input.transferTo(output); // transferTo的作用是?
}

练习

请利用InputStreamOutputStream,编写一个复制文件的程序,它可以带参数运行:

1
java CopyFile.java source.txt copy.txt

下载练习

小结

Java标准库的java.io.OutputStream定义了所有输出流的超类:

  • FileOutputStream实现了文件流输出;
  • ByteArrayOutputStream在内存中模拟一个字节流输出。

某些情况下需要手动调用OutputStreamflush()方法来强制输出缓冲区;

总是使用try(resource)来保证OutputStream正确关闭。

Java的IO标准库提供的InputStream根据来源可以包括:

  • FileInputStream:从文件读取数据,是最终数据源;
  • ServletInputStream:从HTTP请求读取数据,是最终数据源;
  • Socket.getInputStream():从TCP连接读取数据,是最终数据源;

如果我们要给FileInputStream添加缓冲功能,则可以从FileInputStream派生一个类:

1
BufferedFileInputStream extends FileInputStream {}

如果要给FileInputStream添加计算签名的功能,类似的,也可以从FileInputStream派生一个类:

1
DigestFileInputStream extends FileInputStream {}

如果要给FileInputStream添加加密/解密功能,还是可以从FileInputStream派生一个类:

1
CipherFileInputStream extends FileInputStream {}

如果要给FileInputStream添加缓冲和签名的功能,那么我们还需要派生BufferedDigestFileInputStream。如果要给FileInputStream添加缓冲和加解密的功能,则需要派生BufferedCipherFileInputStream

我们发现,给FileInputStream添加3种功能,至少需要3个子类。这3种功能的组合,又需要更多的子类:

1
2
3
4
5
6
7
8
9
10
11
12
13
                          ┌─────────────────┐
│ FileInputStream │
└─────────────────┘

┌───────────┬─────────┼─────────┬───────────┐
│ │ │ │ │
┌───────────────────────┐│┌─────────────────┐│┌─────────────────────┐
│BufferedFileInputStream│││DigestInputStream│││CipherFileInputStream│
└───────────────────────┘│└─────────────────┘│└─────────────────────┘
│ │
┌─────────────────────────────┐ ┌─────────────────────────────┐
│BufferedDigestFileInputStream│ │BufferedCipherFileInputStream│
└─────────────────────────────┘ └─────────────────────────────┘

这还只是针对FileInputStream设计,如果针对另一种InputStream设计,很快会出现子类爆炸的情况。

因此,直接使用继承,为各种InputStream附加更多的功能,根本无法控制代码的复杂度,很快就会失控。

为了解决依赖继承会导致子类数量失控的问题,JDK首先将InputStream分为两大类:

一类是直接提供数据的基础InputStream,例如:

  • FileInputStream
  • ByteArrayInputStream
  • ServletInputStream

一类是提供额外附加功能的InputStream,例如:

  • BufferedInputStream
  • DigestInputStream
  • CipherInputStream

当我们需要给一个“基础”InputStream附加各种功能时,我们先确定这个能提供数据源的InputStream,因为我们需要的数据总得来自某个地方,例如,FileInputStream,数据来源自文件:

1
InputStream file = new FileInputStream("test.gz");

紧接着,我们希望FileInputStream能提供缓冲的功能来提高读取的效率,因此我们用BufferedInputStream包装这个InputStream,得到的包装类型是BufferedInputStream,但它仍然被视为一个InputStream

1
InputStream buffered = new BufferedInputStream(file);

最后,假设该文件已经用gzip压缩了,我们希望直接读取解压缩的内容,就可以再包装一个GZIPInputStream

1
InputStream gzip = new GZIPInputStream(buffered);

无论我们包装多少次,得到的对象始终是InputStream,我们直接用InputStream来引用它,就可以正常读取:

1
2
3
4
5
6
7
8
9
┌─────────────────────────┐
│GZIPInputStream │
│┌───────────────────────┐│
││BufferedFileInputStream││
││┌─────────────────────┐││
│││ FileInputStream │││
││└─────────────────────┘││
│└───────────────────────┘│
└─────────────────────────┘

上述这种通过一个“基础”组件再叠加各种“附加”功能组件的模式,称之为Filter模式(或者装饰器模式:Decorator)。它可以让我们通过少量的类来实现各种功能的组合:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
                 ┌─────────────┐
│ InputStream │
└─────────────┘
▲ ▲
┌────────────────────┐ │ │ ┌─────────────────┐
│ FileInputStream │─┤ └─│FilterInputStream│
└────────────────────┘ │ └─────────────────┘
┌────────────────────┐ │ ▲ ┌───────────────────┐
│ByteArrayInputStream│─┤ ├─│BufferedInputStream│
└────────────────────┘ │ │ └───────────────────┘
┌────────────────────┐ │ │ ┌───────────────────┐
│ ServletInputStream │─┘ ├─│ DataInputStream │
└────────────────────┘ │ └───────────────────┘
│ ┌───────────────────┐
└─│CheckedInputStream │
└───────────────────┘

类似的,OutputStream也是以这种模式来提供各种功能:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
                  ┌─────────────┐
│OutputStream │
└─────────────┘
▲ ▲
┌─────────────────────┐ │ │ ┌──────────────────┐
│ FileOutputStream │─┤ └─│FilterOutputStream│
└─────────────────────┘ │ └──────────────────┘
┌─────────────────────┐ │ ▲ ┌────────────────────┐
│ByteArrayOutputStream│─┤ ├─│BufferedOutputStream│
└─────────────────────┘ │ │ └────────────────────┘
┌─────────────────────┐ │ │ ┌────────────────────┐
│ ServletOutputStream │─┘ ├─│ DataOutputStream │
└─────────────────────┘ │ └────────────────────┘
│ ┌────────────────────┐
└─│CheckedOutputStream │
└────────────────────┘

编写FilterInputStream

我们也可以自己编写FilterInputStream,以便可以把自己的FilterInputStream“叠加”到任何一个InputStream中。

下面的例子演示了如何编写一个CountInputStream,它的作用是对输入的字节进行计数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import java.io.*;

public class Main {
public static void main(String[] args) throws IOException {
byte[] data = "hello, world!".getBytes("UTF-8");
try (CountInputStream input = new CountInputStream(new ByteArrayInputStream(data))) {
int n;
while ((n = input.read()) != -1) {
System.out.println((char)n);
}
System.out.println("Total read " + input.getBytesRead() + " bytes");
}
}
}

class CountInputStream extends FilterInputStream {
private int count = 0;

CountInputStream(InputStream in) {
super(in);
}

public int getBytesRead() {
return this.count;
}

public int read() throws IOException {
int n = in.read();
if (n != -1) {
this.count ++;
}
return n;
}

public int read(byte[] b, int off, int len) throws IOException {
int n = in.read(b, off, len);
if (n != -1) {
this.count += n;
}
return n;
}
}

注意到在叠加多个FilterInputStream,我们只需要持有最外层的InputStream,并且,当最外层的InputStream关闭时(在try(resource)块的结束处自动关闭),内层的InputStreamclose()方法也会被自动调用,并最终调用到最核心的“基础”InputStream,因此不存在资源泄露。

小结

Java的IO标准库使用Filter模式为InputStreamOutputStream增加功能:

  • 可以把一个InputStream和任意个FilterInputStream组合;
  • 可以把一个OutputStream和任意个FilterOutputStream组合。

Filter模式可以在运行期动态增加功能(又称Decorator模式)。

操作Zip

ZipInputStream是一种FilterInputStream,它可以直接读取zip包的内容:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
┌───────────────────┐
│ InputStream │
└───────────────────┘


┌───────────────────┐
│ FilterInputStream │
└───────────────────┘


┌───────────────────┐
│InflaterInputStream│
└───────────────────┘


┌───────────────────┐
│ ZipInputStream │
└───────────────────┘


┌───────────────────┐
│ JarInputStream │
└───────────────────┘

另一个JarInputStream是从ZipInputStream派生,它增加的主要功能是直接读取jar文件里面的MANIFEST.MF文件。因为本质上jar包就是zip包,只是额外附加了一些固定的描述文件。

读取zip包

我们来看看ZipInputStream的基本用法。

我们要创建一个ZipInputStream,通常是传入一个FileInputStream作为数据源,然后,循环调用getNextEntry(),直到返回null,表示zip流结束。

一个ZipEntry表示一个压缩文件或目录,如果是压缩文件,我们就用read()方法不断读取,直到返回-1

1
2
3
4
5
6
7
8
9
10
11
12
try (ZipInputStream zip = new ZipInputStream(new FileInputStream(...))) {
ZipEntry entry = null;
while ((entry = zip.getNextEntry()) != null) {
String name = entry.getName();
if (!entry.isDirectory()) {
int n;
while ((n = zip.read()) != -1) {
...
}
}
}
}

写入zip包

ZipOutputStream是一种FilterOutputStream,它可以直接写入内容到zip包。我们要先创建一个ZipOutputStream,通常是包装一个FileOutputStream,然后,每写入一个文件前,先调用putNextEntry(),然后用write()写入byte[]数据,写入完毕后调用closeEntry()结束这个文件的打包。

1
2
3
4
5
6
7
8
try (ZipOutputStream zip = new ZipOutputStream(new FileOutputStream(...))) {
File[] files = ...
for (File file : files) {
zip.putNextEntry(new ZipEntry(file.getName()));
zip.write(Files.readAllBytes(file.toPath()));
zip.closeEntry();
}
}

上面的代码没有考虑文件的目录结构。如果要实现目录层次结构,new ZipEntry(name)传入的name要用相对路径。

小结

ZipInputStream可以读取zip格式的流,ZipOutputStream可以把多份数据写入zip包;

配合FileInputStreamFileOutputStream就可以读写zip文件。



读取classpath资源

很多Java程序启动的时候,都需要读取配置文件。例如,从一个.properties文件中读取配置:

1
2
3
4
String conf = "C:\\conf\\default.properties";
try (InputStream input = new FileInputStream(conf)) {
// TODO:
}

这段代码要正常执行,必须在C盘创建conf目录,然后在目录里创建default.properties文件。但是,在Linux系统上,路径和Windows的又不一样。

因此,从磁盘的固定目录读取配置文件,不是一个好的办法。

有没有路径无关的读取文件的方式呢?

我们知道,Java存放.class的目录或jar包也可以包含任意其他类型的文件,例如:

  • 配置文件,例如.properties
  • 图片文件,例如.jpg
  • 文本文件,例如.txt.csv
  • ……

从classpath读取文件就可以避免不同环境下文件路径不一致的问题:如果我们把default.properties文件放到classpath中,就不用关心它的实际存放路径。

在classpath中的资源文件,路径总是以开头,我们先获取当前的Class对象,然后调用getResourceAsStream()就可以直接从classpath读取任意的资源文件:

1
2
3
try (InputStream input = getClass().getResourceAsStream("/default.properties")) {
// TODO:
}

调用getResourceAsStream()需要特别注意的一点是,如果资源文件不存在,它将返回null。因此,我们需要检查返回的InputStream是否为null,如果为null,表示资源文件在classpath中没有找到:

1
2
3
4
5
try (InputStream input = getClass().getResourceAsStream("/default.properties")) {
if (input != null) {
// TODO:
}
}

如果我们把默认的配置放到jar包中,再从外部文件系统读取一个可选的配置文件,就可以做到既有默认的配置文件,又可以让用户自己修改配置:

1
2
3
Properties props = new Properties();
props.load(inputStreamFromClassPath("/default.properties"));
props.load(inputStreamFromFile("./conf.properties"));

这样读取配置文件,应用程序启动就更加灵活。

小结

把资源存储在classpath中可以避免文件路径依赖;

Class对象的getResourceAsStream()可以从classpath中读取指定资源;

根据classpath读取资源时,需要检查返回的InputStream是否为null



序列化是指把一个Java对象变成二进制内容,本质上就是一个byte[]数组。

为什么要把Java对象序列化呢?因为序列化后可以把byte[]保存到文件中,或者把byte[]通过网络传输到远程,这样,就相当于把Java对象存储到文件或者通过网络传输出去了。

有序列化,就有反序列化,即把一个二进制内容(也就是byte[]数组)变回Java对象。有了反序列化,保存到文件中的byte[]数组又可以“变回”Java对象,或者从网络上读取byte[]并把它“变回”Java对象。

我们来看看如何把一个Java对象序列化。

一个Java对象要能序列化,必须实现一个特殊的java.io.Serializable接口,它的定义如下:

1
2
public interface Serializable {
}

Serializable接口没有定义任何方法,它是一个空接口。我们把这样的空接口称为“标记接口”(Marker Interface),实现了标记接口的类仅仅是给自身贴了个“标记”,并没有增加任何方法。

序列化

把一个Java对象变为byte[]数组,需要使用ObjectOutputStream。它负责把一个Java对象写入一个字节流:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import java.io.*;
import java.util.Arrays;

public class Main {
public static void main(String[] args) throws IOException {
ByteArrayOutputStream buffer = new ByteArrayOutputStream();
try (ObjectOutputStream output = new ObjectOutputStream(buffer)) {
// 写入int:
output.writeInt(12345);
// 写入String:
output.writeUTF("Hello");
// 写入Object:
output.writeObject(Double.valueOf(123.456));
}
System.out.println(Arrays.toString(buffer.toByteArray()));
}
}

ObjectOutputStream既可以写入基本类型,如intboolean,也可以写入String(以UTF-8编码),还可以写入实现了Serializable接口的Object

因为写入Object时需要大量的类型信息,所以写入的内容很大。

反序列化

ObjectOutputStream相反,ObjectInputStream负责从一个字节流读取Java对象:

1
2
3
4
5
try (ObjectInputStream input = new ObjectInputStream(...)) {
int n = input.readInt();
String s = input.readUTF();
Double d = (Double) input.readObject();
}

除了能读取基本类型和String类型外,调用readObject()可以直接返回一个Object对象。要把它变成一个特定类型,必须强制转型。

readObject()可能抛出的异常有:

  • ClassNotFoundException:没有找到对应的Class;
  • InvalidClassException:Class不匹配。

对于ClassNotFoundException,这种情况常见于一台电脑上的Java程序把一个Java对象,例如,Person对象序列化以后,通过网络传给另一台电脑上的另一个Java程序,但是这台电脑的Java程序并没有定义Person类,所以无法反序列化。

对于InvalidClassException,这种情况常见于序列化的Person对象定义了一个int类型的age字段,但是反序列化时,Person类定义的age字段被改成了long类型,所以导致class不兼容。

为了避免这种class定义变动导致的不兼容,Java的序列化允许class定义一个特殊的serialVersionUID静态变量,用于标识Java类的序列化“版本”,通常可以由IDE自动生成。如果增加或修改了字段,可以改变serialVersionUID的值,这样就能自动阻止不匹配的class版本:

1
2
3
public class Person implements Serializable {
private static final long serialVersionUID = 2709425275741743919L;
}

要特别注意反序列化的几个重要特点:

反序列化时,由JVM直接构造出Java对象,不调用构造方法,构造方法内部的代码,在反序列化时根本不可能执行。

安全性

因为Java的序列化机制可以导致一个实例能直接从byte[]数组创建,而不经过构造方法,因此,它存在一定的安全隐患。一个精心构造的byte[]数组被反序列化后可以执行特定的Java代码,从而导致严重的安全漏洞。

实际上,Java本身提供的基于对象的序列化和反序列化机制既存在安全性问题,也存在兼容性问题。更好的序列化方法是通过JSON这样的通用数据结构来实现,只输出基本类型(包括String)的内容,而不存储任何与代码相关的信息。

小结

可序列化的Java对象必须实现java.io.Serializable接口,类似Serializable这样的空接口被称为“标记接口”(Marker Interface);

反序列化时不调用构造方法,可设置serialVersionUID作为版本号(非必需);

Java的序列化机制仅适用于Java,如果需要与其它语言交换数据,必须使用通用的序列化方法,例如JSON。

Reader是Java的IO库提供的另一个输入流接口。和InputStream的区别是,InputStream是一个字节流,即以byte为单位读取,而Reader是一个字符流,即以char为单位读取:

InputStream Reader
字节流,以byte为单位 字符流,以char为单位
读取字节(-1,0~255):int read() 读取字符(-1,0~65535):int read()
读到字节数组:int read(byte[] b) 读到字符数组:int read(char[] c)

java.io.Reader是所有字符输入流的超类,它最主要的方法是:

1
public int read() throws IOException;

这个方法读取字符流的下一个字符,并返回字符表示的int,范围是0~65535。如果已读到末尾,返回-1

FileReader

FileReaderReader的一个子类,它可以打开文件并获取Reader。下面的代码演示了如何完整地读取一个FileReader的所有字符:

1
2
3
4
5
6
7
8
9
10
11
12
public void readFile() throws IOException {
// 创建一个FileReader对象:
Reader reader = new FileReader("src/readme.txt"); // 字符编码是???
for (;;) {
int n = reader.read(); // 反复调用read()方法,直到返回-1
if (n == -1) {
break;
}
System.out.println((char)n); // 打印char
}
reader.close(); // 关闭流
}

如果我们读取一个纯ASCII编码的文本文件,上述代码工作是没有问题的。但如果文件中包含中文,就会出现乱码,因为FileReader默认的编码与系统相关,例如,Windows系统的默认编码可能是GBK,打开一个UTF-8编码的文本文件就会出现乱码。

要避免乱码问题,我们需要在创建FileReader时指定编码:

1
Reader reader = new FileReader("src/readme.txt", StandardCharsets.UTF_8);

InputStream类似,Reader也是一种资源,需要保证出错的时候也能正确关闭,所以我们需要用try (resource)来保证Reader在无论有没有IO错误的时候都能够正确地关闭:

1
2
3
try (Reader reader = new FileReader("src/readme.txt", StandardCharsets.UTF_8)) {
// TODO
}

Reader还提供了一次性读取若干字符并填充到char[]数组的方法:

1
public int read(char[] c) throws IOException

它返回实际读入的字符个数,最大不超过char[]数组的长度。返回-1表示流结束。

利用这个方法,我们可以先设置一个缓冲区,然后,每次尽可能地填充缓冲区:

1
2
3
4
5
6
7
8
9
public void readFile() throws IOException {
try (Reader reader = new FileReader("src/readme.txt", StandardCharsets.UTF_8)) {
char[] buffer = new char[1000];
int n;
while ((n = reader.read(buffer)) != -1) {
System.out.println("read " + n + " chars.");
}
}
}

CharArrayReader

CharArrayReader可以在内存中模拟一个Reader,它的作用实际上是把一个char[]数组变成一个Reader,这和ByteArrayInputStream非常类似:

1
2
try (Reader reader = new CharArrayReader("Hello".toCharArray())) {
}

StringReader

StringReader可以直接把String作为数据源,它和CharArrayReader几乎一样:

1
2
try (Reader reader = new StringReader("Hello")) {
}

InputStreamReader

ReaderInputStream有什么关系?

除了特殊的CharArrayReaderStringReader,普通的Reader实际上是基于InputStream构造的,因为Reader需要从InputStream中读入字节流(byte),然后,根据编码设置,再转换为char就可以实现字符流。如果我们查看FileReader的源码,它在内部实际上持有一个FileInputStream

既然Reader本质上是一个基于InputStreambytechar的转换器,那么,如果我们已经有一个InputStream,想把它转换为Reader,是完全可行的。InputStreamReader就是这样一个转换器,它可以把任何InputStream转换为Reader。示例代码如下:

1
2
3
4
// 持有InputStream:
InputStream input = new FileInputStream("src/readme.txt");
// 变换为Reader:
Reader reader = new InputStreamReader(input, "UTF-8");

构造InputStreamReader时,我们需要传入InputStream,还需要指定编码,就可以得到一个Reader对象。上述代码可以通过try (resource)更简洁地改写如下:

1
2
3
try (Reader reader = new InputStreamReader(new FileInputStream("src/readme.txt"), "UTF-8")) {
// TODO:
}

上述代码实际上就是FileReader的一种实现方式。

使用try (resource)结构时,当我们关闭Reader时,它会在内部自动调用InputStreamclose()方法,所以,只需要关闭最外层的Reader对象即可。

提示

使用InputStreamReader,可以把一个InputStream转换成一个Reader。

小结

Reader定义了所有字符输入流的超类:

  • FileReader实现了文件字符流输入,使用时需要指定编码;
  • CharArrayReaderStringReader可以在内存中模拟一个字符流输入。

Reader是基于InputStream构造的:可以通过InputStreamReader在指定编码的同时将任何InputStream转换为Reader

总是使用try (resource)保证Reader正确关闭。

Writer

Reader是带编码转换器的InputStream,它把byte转换为char,而Writer就是带编码转换器的OutputStream,它把char转换为byte并输出。

WriterOutputStream的区别如下:

OutputStream Writer
字节流,以byte为单位 字符流,以char为单位
写入字节(0~255):void write(int b) 写入字符(0~65535):void write(int c)
写入字节数组:void write(byte[] b) 写入字符数组:void write(char[] c)
无对应方法 写入String:void write(String s)

Writer是所有字符输出流的超类,它提供的方法主要有:

  • 写入一个字符(0~65535):void write(int c)
  • 写入字符数组的所有字符:void write(char[] c)
  • 写入String表示的所有字符:void write(String s)

FileWriter

FileWriter就是向文件中写入字符流的Writer。它的使用方法和FileReader类似:

1
2
3
4
5
try (Writer writer = new FileWriter("readme.txt", StandardCharsets.UTF_8)) {
writer.write('H'); // 写入单个字符
writer.write("Hello".toCharArray()); // 写入char[]
writer.write("Hello"); // 写入String
}

CharArrayWriter

CharArrayWriter可以在内存中创建一个Writer,它的作用实际上是构造一个缓冲区,可以写入char,最后得到写入的char[]数组,这和ByteArrayOutputStream非常类似:

1
2
3
4
5
6
try (CharArrayWriter writer = new CharArrayWriter()) {
writer.write(65);
writer.write(66);
writer.write(67);
char[] data = writer.toCharArray(); // { 'A', 'B', 'C' }
}

StringWriter

StringWriter也是一个基于内存的Writer,它和CharArrayWriter类似。实际上,StringWriter在内部维护了一个StringBuffer,并对外提供了Writer接口。

OutputStreamWriter

除了CharArrayWriterStringWriter外,普通的Writer实际上是基于OutputStream构造的,它接收char,然后在内部自动转换成一个或多个byte,并写入OutputStream。因此,OutputStreamWriter就是一个将任意的OutputStream转换为Writer的转换器:

1
2
3
try (Writer writer = new OutputStreamWriter(new FileOutputStream("readme.txt"), "UTF-8")) {
// TODO:
}

上述代码实际上就是FileWriter的一种实现方式。这和上一节的InputStreamReader是一样的。

小结

Writer定义了所有字符输出流的超类:

  • FileWriter实现了文件字符流输出;
  • CharArrayWriterStringWriter在内存中模拟一个字符流输出。

使用try (resource)保证Writer正确关闭;

Writer是基于OutputStream构造的,可以通过OutputStreamWriterOutputStream转换为Writer,转换时需要指定编码。



PrintStream和PrintWriter

PrintStream是一种FilterOutputStream,它在OutputStream的接口上,额外提供了一些写入各种数据类型的方法:

  • 写入intprint(int)
  • 写入booleanprint(boolean)
  • 写入Stringprint(String)
  • 写入Objectprint(Object),实际上相当于print(object.toString())

以及对应的一组println()方法,它会自动加上换行符。

我们经常使用的System.out.println()实际上就是使用PrintStream打印各种数据。其中,System.out是系统默认提供的PrintStream,表示标准输出:

1
2
3
System.out.print(12345); // 输出12345
System.out.print(new Object()); // 输出类似java.lang.Object@3c7a835a
System.out.println("Hello"); // 输出Hello并换行

System.err是系统默认提供的标准错误输出。

PrintStreamOutputStream相比,除了添加了一组print()/println()方法,可以打印各种数据类型,比较方便外,它还有一个额外的优点,就是不会抛出IOException,这样我们在编写代码的时候,就不必捕获IOException

PrintWriter

PrintStream最终输出的总是byte数据,而PrintWriter则是扩展了Writer接口,它的print()/println()方法最终输出的是char数据。两者的使用方法几乎是一模一样的:

1
2
3
4
5
6
7
8
9
10
11
12
13
import java.io.*;

public class Main {
public static void main(String[] args) {
StringWriter buffer = new StringWriter();
try (PrintWriter pw = new PrintWriter(buffer)) {
pw.println("Hello");
pw.println(12345);
pw.println(true);
}
System.out.println(buffer.toString());
}
}

小结

PrintStream是一种能接收各种数据类型的输出,打印数据时比较方便:

  • System.out是标准输出;
  • System.err是标准错误输出。

PrintWriter是基于Writer的输出。



使用Files

从Java 7开始,提供了Files这个工具类,能极大地方便我们读写文件。

虽然Filesjava.nio包里面的类,但他俩封装了很多读写文件的简单方法,例如,我们要把一个文件的全部内容读取为一个byte[],可以这么写:

1
byte[] data = Files.readAllBytes(Path.of("/path/to/file.txt"));

如果是文本文件,可以把一个文件的全部内容读取为String

1
2
3
4
5
6
// 默认使用UTF-8编码读取:
String content1 = Files.readString(Path.of("/path/to/file.txt"));
// 可指定编码:
String content2 = Files.readString(Path.of("/path", "to", "file.txt"), StandardCharsets.ISO_8859_1);
// 按行读取并返回每行内容:
List<String> lines = Files.readAllLines(Path.of("/path/to/file.txt"));

写入文件也非常方便:

1
2
3
4
5
6
7
8
// 写入二进制文件:
byte[] data = ...
Files.write(Path.of("/path/to/file.txt"), data);
// 写入文本并指定编码:
Files.writeString(Path.of("/path/to/file.txt"), "文本内容...", StandardCharsets.ISO_8859_1);
// 按行写入文本:
List<String> lines = ...
Files.write(Path.of("/path/to/file.txt"), lines);

此外,Files工具类还有copy()delete()exists()move()等快捷方法操作文件和目录。

最后需要特别注意的是,Files提供的读写方法,受内存限制,只能读写小文件,例如配置文件等,不可一次读入几个G的大文件。读写大型文件仍然要使用文件流,每次只读写一部分文件内容。

小结

对于简单的小文件读写操作,可以使用Files工具类简化代码。



留言與分享

JAVA-集合

分類 编程语言, Java

Java集合简介

什么是集合(Collection)?集合就是“由若干个确定的元素所构成的整体”。例如,5只小兔构成的集合:

1
2
3
4
5
6
7
┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐

│ (\_(\ (\_/) (\_/) (\_/) (\(\ │
( -.-) (•.•) (>.<) (^.^) (='.')
│ C(")_(") (")_(") (")_(") (")_(") O(_")") │

└ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘

在数学中,我们经常遇到集合的概念。例如:

  • 有限集合:
    • 一个班所有的同学构成的集合;
    • 一个网站所有的商品构成的集合;
  • 无限集合:
    • 全体自然数集合:1,2,3,……
    • 有理数集合;
    • 实数集合;

为什么要在计算机中引入集合呢?这是为了便于处理一组类似的数据,例如:

  • 计算所有同学的总成绩和平均成绩;
  • 列举所有的商品名称和价格;
  • ……

在Java中,如果一个Java对象可以在内部持有若干其他Java对象,并对外提供访问接口,我们把这种Java对象称为集合。很显然,Java的数组可以看作是一种集合:

1
2
3
String[] ss = new String[10]; // 可以持有10个String对象
ss[0] = "Hello"; // 可以放入String对象
String first = ss[0]; // 可以获取String对象

既然Java提供了数组这种数据类型,可以充当集合,那么,我们为什么还需要其他集合类?这是因为数组有如下限制:

  • 数组初始化后大小不可变;
  • 数组只能按索引顺序存取。

因此,我们需要各种不同类型的集合类来处理不同的数据,例如:

  • 可变大小的顺序链表;
  • 保证无重复元素的集合;

Collection

Java标准库自带的java.util包提供了集合类:Collection,它是除Map外所有其他集合类的根接口。Java的java.util包主要提供了以下三种类型的集合:

  • List:一种有序列表的集合,例如,按索引排列的StudentList
  • Set:一种保证没有重复元素的集合,例如,所有无重复名称的StudentSet
  • Map:一种通过键值(key-value)查找的映射表集合,例如,根据Studentname查找对应StudentMap

Java集合的设计有几个特点:一是实现了接口和实现类相分离,例如,有序表的接口是List,具体的实现类有ArrayListLinkedList等,二是支持泛型,我们可以限制在一个集合中只能放入同一种数据类型的元素,例如:

1
List<String> list = new ArrayList<>(); // 只能放入String类型

最后,Java访问集合总是通过统一的方式——迭代器(Iterator)来实现,它最明显的好处在于无需知道集合内部元素是按什么方式存储的。

由于Java的集合设计非常久远,中间经历过大规模改进,我们要注意到有一小部分集合类是遗留类,不应该继续使用:

  • Hashtable:一种线程安全的Map实现;
  • Vector:一种线程安全的List实现;
  • Stack:基于Vector实现的LIFO的栈。

还有一小部分接口是遗留接口,也不应该继续使用:

  • Enumeration<E>:已被Iterator<E>取代。

小结

Java的集合类定义在java.util包中,支持泛型,主要提供了3种集合类,包括ListSetMap。Java集合使用统一的Iterator遍历,尽量不要使用遗留接口。



在集合类中,List是最基础的一种集合:它是一种有序列表。

List的行为和数组几乎完全相同:List内部按照放入元素的先后顺序存放,每个元素都可以通过索引确定自己的位置,List的索引和数组一样,从0开始。

数组和List类似,也是有序结构,如果我们使用数组,在添加和删除元素的时候,会非常不方便。例如,从一个已有的数组{'A', 'B', 'C', 'D', 'E'}中删除索引为2的元素:

1
2
3
4
5
6
7
8
9
10
11
┌───┬───┬───┬───┬───┬───┐
│ A │ B │ C │ D │ E │ │
└───┴───┴───┴───┴───┴───┘
│ │
┌───┘ │
│ ┌───┘
│ │
▼ ▼
┌───┬───┬───┬───┬───┬───┐
│ A │ B │ D │ E │ │ │
└───┴───┴───┴───┴───┴───┘

这个“删除”操作实际上是把'C'后面的元素依次往前挪一个位置,而“添加”操作实际上是把指定位置以后的元素都依次向后挪一个位置,腾出来的位置给新加的元素。这两种操作,用数组实现非常麻烦。

因此,在实际应用中,需要增删元素的有序列表,我们使用最多的是ArrayList。实际上,ArrayList在内部使用了数组来存储所有元素。例如,一个ArrayList拥有5个元素,实际数组大小为6(即有一个空位):

1
2
3
4
size=5
┌───┬───┬───┬───┬───┬───┐
│ A │ B │ C │ D │ E │ │
└───┴───┴───┴───┴───┴───┘

当添加一个元素并指定索引到ArrayList时,ArrayList自动移动需要移动的元素:

1
2
3
4
size=5
┌───┬───┬───┬───┬───┬───┐
│ A │ B │ │ C │ D │ E │
└───┴───┴───┴───┴───┴───┘

然后,往内部指定索引的数组位置添加一个元素,然后把size1

1
2
3
4
size=6
┌───┬───┬───┬───┬───┬───┐
│ A │ B │ F │ C │ D │ E │
└───┴───┴───┴───┴───┴───┘

继续添加元素,但是数组已满,没有空闲位置的时候,ArrayList先创建一个更大的新数组,然后把旧数组的所有元素复制到新数组,紧接着用新数组取代旧数组:

1
2
3
4
size=6
┌───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┐
│ A │ B │ F │ C │ D │ E │ │ │ │ │ │ │
└───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┘

现在,新数组就有了空位,可以继续添加一个元素到数组末尾,同时size1

1
2
3
4
size=7
┌───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┐
│ A │ B │ F │ C │ D │ E │ G │ │ │ │ │ │
└───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┘

可见,ArrayList把添加和删除的操作封装起来,让我们操作List类似于操作数组,却不用关心内部元素如何移动。

我们考察List<E>接口,可以看到几个主要的接口方法:

  • 在末尾添加一个元素:boolean add(E e)
  • 在指定索引添加一个元素:boolean add(int index, E e)
  • 删除指定索引的元素:E remove(int index)
  • 删除某个元素:boolean remove(Object e)
  • 获取指定索引的元素:E get(int index)
  • 获取链表大小(包含元素的个数):int size()

但是,实现List接口并非只能通过数组(即ArrayList的实现方式)来实现,另一种LinkedList通过“链表”也实现了List接口。在LinkedList中,它的内部每个元素都指向下一个元素:

1
2
3
        ┌───┬───┐   ┌───┬───┐   ┌───┬───┐   ┌───┬───┐
HEAD ──▶│ A │ ●─┼──▶│ B │ ●─┼──▶│ C │ ●─┼──▶│ D │ │
└───┴───┘ └───┴───┘ └───┴───┘ └───┴───┘

我们来比较一下ArrayListLinkedList

ArrayList LinkedList
获取指定元素 速度很快 需要从头开始查找元素
添加元素到末尾 速度很快 速度很快
在指定位置添加/删除 需要移动元素 不需要移动元素
内存占用 较大

通常情况下,我们总是优先使用ArrayList

List的特点

使用List时,我们要关注List接口的规范。List接口允许我们添加重复的元素,即List内部的元素可以重复:

1
2
3
4
5
6
7
8
9
10
11
12
import java.util.ArrayList;
import java.util.List;

public class Main {
public static void main(String[] args) {
List<String> list = new ArrayList<>();
list.add("apple"); // size=1
list.add("pear"); // size=2
list.add("apple"); // 允许重复添加元素,size=3
System.out.println(list.size());
}
}

List还允许添加null

1
2
3
4
5
6
7
8
9
10
11
12
13
import java.util.ArrayList;
import java.util.List;

public class Main {
public static void main(String[] args) {
List<String> list = new ArrayList<>();
list.add("apple"); // size=1
list.add(null); // size=2
list.add("pear"); // size=3
String second = list.get(1); // null
System.out.println(second);
}
}

创建List

除了使用ArrayListLinkedList,我们还可以通过List接口提供的of()方法,根据给定元素快速创建List

1
List<Integer> list = List.of(1, 2, 5);

但是List.of()方法不接受null值,如果传入null,会抛出NullPointerException异常。

遍历List

和数组类型类似,我们要遍历一个List,完全可以用for循环根据索引配合get(int)方法遍历:

1
2
3
4
5
6
7
8
9
10
11
import java.util.List;

public class Main {
public static void main(String[] args) {
List<String> list = List.of("apple", "pear", "banana");
for (int i=0; i<list.size(); i++) {
String s = list.get(i);
System.out.println(s);
}
}
}

但这种方式并不推荐,一是代码复杂,二是因为get(int)方法只有ArrayList的实现是高效的,换成LinkedList后,索引越大,访问速度越慢。

所以我们要始终坚持使用迭代器Iterator来访问ListIterator本身也是一个对象,但它是由List的实例调用iterator()方法的时候创建的。Iterator对象知道如何遍历一个List,并且不同的List类型,返回的Iterator对象实现也是不同的,但总是具有最高的访问效率。

Iterator对象有两个方法:boolean hasNext()判断是否有下一个元素,E next()返回下一个元素。因此,使用Iterator遍历List代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
import java.util.Iterator;
import java.util.List;

public class Main {
public static void main(String[] args) {
List<String> list = List.of("apple", "pear", "banana");
for (Iterator<String> it = list.iterator(); it.hasNext(); ) {
String s = it.next();
System.out.println(s);
}
}
}

有童鞋可能觉得使用Iterator访问List的代码比使用索引更复杂。但是,要记住,通过Iterator遍历List永远是最高效的方式。并且,由于Iterator遍历是如此常用,所以,Java的for each循环本身就可以帮我们使用Iterator遍历。把上面的代码再改写如下:

1
2
3
4
5
6
7
8
9
10
import java.util.List;

public class Main {
public static void main(String[] args) {
List<String> list = List.of("apple", "pear", "banana");
for (String s : list) {
System.out.println(s);
}
}
}

上述代码就是我们编写遍历List的常见代码。

实际上,只要实现了Iterable接口的集合类都可以直接用for each循环来遍历,Java编译器本身并不知道如何遍历集合对象,但它会自动把for each循环变成Iterator的调用,原因就在于Iterable接口定义了一个Iterator<E> iterator()方法,强迫集合类必须返回一个Iterator实例。

List和Array转换

List变为Array有三种方法,第一种是调用toArray()方法直接返回一个Object[]数组:

1
2
3
4
5
6
7
8
9
10
11
import java.util.List;

public class Main {
public static void main(String[] args) {
List<String> list = List.of("apple", "pear", "banana");
Object[] array = list.toArray();
for (Object s : array) {
System.out.println(s);
}
}
}

这种方法会丢失类型信息,所以实际应用很少。

第二种方式是给toArray(T[])传入一个类型相同的ArrayList内部自动把元素复制到传入的Array中:

1
2
3
4
5
6
7
8
9
10
11
import java.util.List;

public class Main {
public static void main(String[] args) {
List<Integer> list = List.of(12, 34, 56);
Integer[] array = list.toArray(new Integer[3]);
for (Integer n : array) {
System.out.println(n);
}
}
}

注意到这个toArray(T[])方法的泛型参数<T>并不是List接口定义的泛型参数<E>,所以,我们实际上可以传入其他类型的数组,例如我们传入Number类型的数组,返回的仍然是Number类型:

1
2
3
4
5
6
7
8
9
10
11
import java.util.List;

public class Main {
public static void main(String[] args) {
List<Integer> list = List.of(12, 34, 56);
Number[] array = list.toArray(new Number[3]);
for (Number n : array) {
System.out.println(n);
}
}
}

但是,如果我们传入类型不匹配的数组,例如,String[]类型的数组,由于List的元素是Integer,所以无法放入String数组,这个方法会抛出ArrayStoreException

如果我们传入的数组大小和List实际的元素个数不一致怎么办?根据List接口的文档,我们可以知道:

如果传入的数组不够大,那么List内部会创建一个新的刚好够大的数组,填充后返回;如果传入的数组比List元素还要多,那么填充完元素后,剩下的数组元素一律填充null

实际上,最常用的是传入一个“恰好”大小的数组:

1
Integer[] array = list.toArray(new Integer[list.size()]);

最后一种更简洁的写法是通过List接口定义的T[] toArray(IntFunction<T[]> generator)方法:

1
Integer[] array = list.toArray(Integer[]::new);

这种函数式写法我们会在后续讲到。

反过来,把Array变为List就简单多了,通过List.of(T...)方法最简单:

1
2
Integer[] array = { 1, 2, 3 };
List<Integer> list = List.of(array);

对于JDK 11之前的版本,可以使用Arrays.asList(T...)方法把数组转换成List

要注意的是,返回的List不一定就是ArrayList或者LinkedList,因为List只是一个接口,如果我们调用List.of(),它返回的是一个只读List

1
2
3
4
5
6
7
8
import java.util.List;

public class Main {
public static void main(String[] args) {
List<Integer> list = List.of(12, 34, 56);
list.add(999); // UnsupportedOperationException
}
}

对只读List调用add()remove()方法会抛出UnsupportedOperationException

练习

给定一组连续的整数,例如:10,11,12,……,20,但其中缺失一个数字,试找出缺失的数字:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import java.util.*;

public class Main {
public static void main(String[] args) {
// 构造从start到end的序列:
final int start = 10;
final int end = 20;
List<Integer> list = new ArrayList<>();
for (int i = start; i <= end; i++) {
list.add(i);
}
// 随机删除List中的一个元素:
int removed = list.remove((int) (Math.random() * list.size()));
int found = findMissingNumber(start, end, list);
System.out.println(list.toString());
System.out.println("missing number: " + found);
System.out.println(removed == found ? "测试成功" : "测试失败");
}

static int findMissingNumber(int start, int end, List<Integer> list) {
return 0;
}
}

增强版:和上述题目一样,但整数不再有序,试找出缺失的数字:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import java.util.*;

public class Main {
public static void main(String[] args) {
// 构造从start到end的序列:
final int start = 10;
final int end = 20;
List<Integer> list = new ArrayList<>();
for (int i = start; i <= end; i++) {
list.add(i);
}
// 洗牌算法shuffle可以随机交换List中的元素位置:
Collections.shuffle(list);
// 随机删除List中的一个元素:
int removed = list.remove((int) (Math.random() * list.size()));
int found = findMissingNumber(start, end, list);
System.out.println(list.toString());
System.out.println("missing number: " + found);
System.out.println(removed == found ? "测试成功" : "测试失败");
}

static int findMissingNumber(int start, int end, List<Integer> list) {
return 0;
}
}

下载练习

小结

List是按索引顺序访问的长度可变的有序表,优先使用ArrayList而不是LinkedList

可以直接使用for each遍历List

List可以和Array相互转换。

我们知道List是一种有序链表:List内部按照放入元素的先后顺序存放,并且每个元素都可以通过索引确定自己的位置。

List还提供了boolean contains(Object o)方法来判断List是否包含某个指定元素。此外,int indexOf(Object o)方法可以返回某个元素的索引,如果元素不存在,就返回-1

我们来看一个例子:

1
2
3
4
5
6
7
8
9
10
11
import java.util.List;

public class Main {
public static void main(String[] args) {
List<String> list = List.of("A", "B", "C");
System.out.println(list.contains("C")); // true
System.out.println(list.contains("X")); // false
System.out.println(list.indexOf("C")); // 2
System.out.println(list.indexOf("X")); // -1
}
}

这里我们注意一个问题,我们往List中添加的"C"和调用contains("C")传入的"C"是不是同一个实例?

如果这两个"C"不是同一个实例,这段代码是否还能得到正确的结果?我们可以改写一下代码测试一下:

1
2
3
4
5
6
7
8
9
import java.util.List;

public class Main {
public static void main(String[] args) {
List<String> list = List.of("A", "B", "C");
System.out.println(list.contains(new String("C"))); // true or false?
System.out.println(list.indexOf(new String("C"))); // 2 or -1?
}
}

因为我们传入的是new String("C"),所以一定是不同的实例。结果仍然符合预期,这是为什么呢?

因为List内部并不是通过==判断两个元素是否相等,而是使用equals()方法判断两个元素是否相等,例如contains()方法可以实现如下:

1
2
3
4
5
6
7
8
9
10
11
public class ArrayList {
Object[] elementData;
public boolean contains(Object o) {
for (int i = 0; i < elementData.length; i++) {
if (o.equals(elementData[i])) {
return true;
}
}
return false;
}
}

因此,要正确使用Listcontains()indexOf()这些方法,放入的实例必须正确覆写equals()方法,否则,放进去的实例,查找不到。我们之所以能正常放入StringInteger这些对象,是因为Java标准库定义的这些类已经正确实现了equals()方法。

我们以Person对象为例,测试一下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import java.util.List;

public class Main {
public static void main(String[] args) {
List<Person> list = List.of(
new Person("Xiao Ming"),
new Person("Xiao Hong"),
new Person("Bob")
);
System.out.println(list.contains(new Person("Bob"))); // false
}
}

class Person {
String name;
public Person(String name) {
this.name = name;
}
}

不出意外,虽然放入了new Person("Bob"),但是用另一个new Person("Bob")查询不到,原因就是Person类没有覆写equals()方法。

编写equals

如何正确编写equals()方法?equals()方法要求我们必须满足以下条件:

  • 自反性(Reflexive):对于非nullx来说,x.equals(x)必须返回true
  • 对称性(Symmetric):对于非nullxy来说,如果x.equals(y)true,则y.equals(x)也必须为true
  • 传递性(Transitive):对于非nullxyz来说,如果x.equals(y)truey.equals(z)也为true,那么x.equals(z)也必须为true
  • 一致性(Consistent):对于非nullxy来说,只要xy状态不变,则x.equals(y)总是一致地返回true或者false
  • null的比较:即x.equals(null)永远返回false

上述规则看上去似乎非常复杂,但其实代码实现equals()方法是很简单的,我们以Person类为例:

1
2
3
4
public class Person {
public String name;
public int age;
}

首先,我们要定义“相等”的逻辑含义。对于Person类,如果name相等,并且age相等,我们就认为两个Person实例相等。

因此,编写equals()方法如下:

1
2
3
4
5
6
public boolean equals(Object o) {
if (o instanceof Person p) {
return this.name.equals(p.name) && this.age == p.age;
}
return false;
}

对于引用字段比较,我们使用equals(),对于基本类型字段的比较,我们使用==

如果this.namenull,那么equals()方法会报错,因此,需要继续改写如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
public boolean equals(Object o) {
if (o instanceof Person p) {
boolean nameEquals = false;
if (this.name == null && p.name == null) {
nameEquals = true;
}
if (this.name != null) {
nameEquals = this.name.equals(p.name);
}
return nameEquals && this.age == p.age;
}
return false;
}

如果Person有好几个引用类型的字段,上面的写法就太复杂了。要简化引用类型的比较,我们使用Objects.equals()静态方法:

1
2
3
4
5
6
public boolean equals(Object o) {
if (o instanceof Person p) {
return Objects.equals(this.name, p.name) && this.age == p.age;
}
return false;
}

因此,我们总结一下equals()方法的正确编写方法:

  1. 先确定实例“相等”的逻辑,即哪些字段相等,就认为实例相等;
  2. instanceof判断传入的待比较的Object是不是当前类型,如果是,继续比较,否则,返回false
  3. 对引用类型用Objects.equals()比较,对基本类型直接用==比较。

使用Objects.equals()比较两个引用类型是否相等的目的是省去了判断null的麻烦。两个引用类型都是null时它们也是相等的。

如果不调用Listcontains()indexOf()这些方法,那么放入的元素就不需要实现equals()方法。

练习

给Person类增加equals方法,使得调用indexOf()方法返回正常:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import java.util.List;
import java.util.Objects;

public class Main {
public static void main(String[] args) {
List<Person> list = List.of(
new Person("Xiao", "Ming", 18),
new Person("Xiao", "Hong", 25),
new Person("Bob", "Smith", 20)
);
boolean exist = list.contains(new Person("Bob", "Smith", 20));
System.out.println(exist ? "测试成功!" : "测试失败!");
}
}

class Person {
String firstName;
String lastName;
int age;
public Person(String firstName, String lastName, int age) {
this.firstName = firstName;
this.lastName = lastName;
this.age = age;
}
}

下载练习

小结

List中查找元素时,List的实现类通过元素的equals()方法比较两个元素是否相等,因此,放入的元素必须正确覆写equals()方法,Java标准库提供的StringInteger等已经覆写了equals()方法;

编写equals()方法可借助Objects.equals()判断。

如果不在List中查找元素,就不必覆写equals()方法。

我们知道,List是一种顺序列表,如果有一个存储学生Student实例的List,要在List中根据name查找某个指定的Student的分数,应该怎么办?

最简单的方法是遍历List并判断name是否相等,然后返回指定元素:

1
2
3
4
5
6
7
8
9
List<Student> list = ...
Student target = null;
for (Student s : list) {
if ("Xiao Ming".equals(s.name)) {
target = s;
break;
}
}
System.out.println(target.score);

这种需求其实非常常见,即通过一个键去查询对应的值。使用List来实现存在效率非常低的问题,因为平均需要扫描一半的元素才能确定,而Map这种键值(key-value)映射表的数据结构,作用就是能高效通过key快速查找value(元素)。

Map来实现根据name查询某个Student的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import java.util.HashMap;
import java.util.Map;

public class Main {
public static void main(String[] args) {
Student s = new Student("Xiao Ming", 99);
Map<String, Student> map = new HashMap<>();
map.put("Xiao Ming", s); // 将"Xiao Ming"和Student实例映射并关联
Student target = map.get("Xiao Ming"); // 通过key查找并返回映射的Student实例
System.out.println(target == s); // true,同一个实例
System.out.println(target.score); // 99
Student another = map.get("Bob"); // 通过另一个key查找
System.out.println(another); // 未找到返回null
}
}

class Student {
public String name;
public int score;
public Student(String name, int score) {
this.name = name;
this.score = score;
}
}

通过上述代码可知:Map<K, V>是一种键-值映射表,当我们调用put(K key, V value)方法时,就把keyvalue做了映射并放入Map。当我们调用V get(K key)时,就可以通过key获取到对应的value。如果key不存在,则返回null。和List类似,Map也是一个接口,最常用的实现类是HashMap

如果只是想查询某个key是否存在,可以调用boolean containsKey(K key)方法。

如果我们在存储Map映射关系的时候,对同一个key调用两次put()方法,分别放入不同的value,会有什么问题呢?例如:

1
2
3
4
5
6
7
8
9
10
11
12
13
import java.util.HashMap;
import java.util.Map;

public class Main {
public static void main(String[] args) {
Map<String, Integer> map = new HashMap<>();
map.put("apple", 123);
map.put("pear", 456);
System.out.println(map.get("apple")); // 123
map.put("apple", 789); // 再次放入apple作为key,但value变为789
System.out.println(map.get("apple")); // 789
}
}

重复放入key-value并不会有任何问题,但是一个key只能关联一个value。在上面的代码中,一开始我们把key对象"apple"映射到Integer对象123,然后再次调用put()方法把"apple"映射到789,这时,原来关联的value对象123就被“冲掉”了。实际上,put()方法的签名是V put(K key, V value),如果放入的key已经存在,put()方法会返回被删除的旧的value,否则,返回null

始终牢记

Map中不存在重复的key,因为放入相同的key,只会把原有的key-value对应的value给替换掉。

此外,在一个Map中,虽然key不能重复,但value是可以重复的:

1
2
3
Map<String, Integer> map = new HashMap<>();
map.put("apple", 123);
map.put("pear", 123); // ok

遍历Map

Map来说,要遍历key可以使用for each循环遍历Map实例的keySet()方法返回的Set集合,它包含不重复的key的集合:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import java.util.HashMap;
import java.util.Map;

public class Main {
public static void main(String[] args) {
Map<String, Integer> map = new HashMap<>();
map.put("apple", 123);
map.put("pear", 456);
map.put("banana", 789);
for (String key : map.keySet()) {
Integer value = map.get(key);
System.out.println(key + " = " + value);
}
}
}

同时遍历keyvalue可以使用for each循环遍历Map对象的entrySet()集合,它包含每一个key-value映射:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import java.util.HashMap;
import java.util.Map;

public class Main {
public static void main(String[] args) {
Map<String, Integer> map = new HashMap<>();
map.put("apple", 123);
map.put("pear", 456);
map.put("banana", 789);
for (Map.Entry<String, Integer> entry : map.entrySet()) {
String key = entry.getKey();
Integer value = entry.getValue();
System.out.println(key + " = " + value);
}
}
}

MapList不同的是,Map存储的是key-value的映射关系,并且,它不保证顺序。在遍历的时候,遍历的顺序既不一定是put()时放入的key的顺序,也不一定是key的排序顺序。使用Map时,任何依赖顺序的逻辑都是不可靠的。以HashMap为例,假设我们放入"A""B""C"这3个key,遍历的时候,每个key会保证被遍历一次且仅遍历一次,但顺序完全没有保证,甚至对于不同的JDK版本,相同的代码遍历的输出顺序都是不同的!

注意

遍历Map时,不可假设输出的key是有序的!

练习

请编写一个根据name查找score的程序,并利用Map充当缓存,以提高查找效率:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import java.util.*;

public class Main {
public static void main(String[] args) {
List<Student> list = List.of(
new Student("Bob", 78),
new Student("Alice", 85),
new Student("Brush", 66),
new Student("Newton", 99));
var holder = new Students(list);
System.out.println(holder.getScore("Bob") == 78 ? "测试成功!" : "测试失败!");
System.out.println(holder.getScore("Alice") == 85 ? "测试成功!" : "测试失败!");
System.out.println(holder.getScore("Tom") == -1 ? "测试成功!" : "测试失败!");
}
}

class Students {
List<Student> list;
Map<String, Integer> cache;

Students(List<Student> list) {
this.list = list;
cache = new HashMap<>();
}

/**
* 根据name查找score,找到返回score,未找到返回-1
*/
int getScore(String name) {
// 先在Map中查找:
Integer score = this.cache.get(name);
if (score == null) {
// TODO:
}
return score == null ? -1 : score.intValue();
}

Integer findInList(String name) {
for (var ss : this.list) {
if (ss.name.equals(name)) {
return ss.score;
}
}
return null;
}
}

class Student {
String name;
int score;

Student(String name, int score) {
this.name = name;
this.score = score;
}
}

下载练习

小结

Map是一种映射表,可以通过key快速查找value

可以通过for each遍历keySet(),也可以通过for each遍历entrySet(),直接获取key-value

最常用的一种Map实现是HashMap

我们知道Map是一种键-值(key-value)映射表,可以通过key快速查找对应的value。

HashMap为例,观察下面的代码:

1
2
3
4
5
6
7
Map<String, Person> map = new HashMap<>();
map.put("a", new Person("Xiao Ming"));
map.put("b", new Person("Xiao Hong"));
map.put("c", new Person("Xiao Jun"));

map.get("a"); // Person("Xiao Ming")
map.get("x"); // null

HashMap之所以能根据key直接拿到value,原因是它内部通过空间换时间的方法,用一个大数组存储所有value,并根据key直接计算出value应该存储在哪个索引:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
  ┌───┐
0 │ │
├───┤
1 │ ●─┼───▶ Person("Xiao Ming")
├───┤
2 │ │
├───┤
3 │ │
├───┤
4 │ │
├───┤
5 │ ●─┼───▶ Person("Xiao Hong")
├───┤
6 │ ●─┼───▶ Person("Xiao Jun")
├───┤
7 │ │
└───┘

如果key的值为"a",计算得到的索引总是1,因此返回valuePerson("Xiao Ming"),如果key的值为"b",计算得到的索引总是5,因此返回valuePerson("Xiao Hong"),这样,就不必遍历整个数组,即可直接读取key对应的value

当我们使用key存取value的时候,就会引出一个问题:

我们放入Mapkey是字符串"a",但是,当我们获取Mapvalue时,传入的变量不一定就是放入的那个key对象。

换句话讲,两个key应该是内容相同,但不一定是同一个对象。测试代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import java.util.HashMap;
import java.util.Map;

public class Main {
public static void main(String[] args) {
String key1 = "a";
Map<String, Integer> map = new HashMap<>();
map.put(key1, 123);

String key2 = new String("a");
map.get(key2); // 123

System.out.println(key1 == key2); // false
System.out.println(key1.equals(key2)); // true
}
}

因为在Map的内部,对key做比较是通过equals()实现的,这一点和List查找元素需要正确覆写equals()是一样的,即正确使用Map必须保证:作为key的对象必须正确覆写equals()方法。

我们经常使用String作为key,因为String已经正确覆写了equals()方法。但如果我们放入的key是一个自己写的类,就必须保证正确覆写了equals()方法。

我们再思考一下HashMap为什么能通过key直接计算出value存储的索引。相同的key对象(使用equals()判断时返回true)必须要计算出相同的索引,否则,相同的key每次取出的value就不一定对。

通过key计算索引的方式就是调用key对象的hashCode()方法,它返回一个int整数。HashMap正是通过这个方法直接定位key对应的value的索引,继而直接返回value

因此,正确使用Map必须保证:

  1. 作为key的对象必须正确覆写equals()方法,相等的两个key实例调用equals()必须返回true
  2. 作为key的对象还必须正确覆写hashCode()方法,且hashCode()方法要严格遵循以下规范:
    • 如果两个对象相等,则两个对象的hashCode()必须相等;
    • 如果两个对象不相等,则两个对象的hashCode()尽量不要相等。

即对应两个实例ab

  • 如果ab相等,那么a.equals(b)一定为true,则a.hashCode()必须等于b.hashCode()
  • 如果ab不相等,那么a.equals(b)一定为false,则a.hashCode()b.hashCode()尽量不要相等。

上述第一条规范是正确性,必须保证实现,否则HashMap不能正常工作。

而第二条如果尽量满足,则可以保证查询效率,因为不同的对象,如果返回相同的hashCode(),会造成Map内部存储冲突,使存取的效率下降。

正确编写equals()的方法我们已经在编写equals方法一节中讲过了,以Person类为例:

1
2
3
4
5
public class Person {
String firstName;
String lastName;
int age;
}

把需要比较的字段找出来:

  • firstName
  • lastName
  • age

然后,引用类型使用Objects.equals()比较,基本类型使用==比较。

在正确实现equals()的基础上,我们还需要正确实现hashCode(),即上述3个字段分别相同的实例,hashCode()返回的int必须相同:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
public class Person {
String firstName;
String lastName;
int age;

@Override
int hashCode() {
int h = 0;
h = 31 * h + firstName.hashCode();
h = 31 * h + lastName.hashCode();
h = 31 * h + age;
return h;
}
}

注意到String类已经正确实现了hashCode()方法,我们在计算PersonhashCode()时,反复使用31*h,这样做的目的是为了尽量把不同的Person实例的hashCode()均匀分布到整个int范围。

和实现equals()方法遇到的问题类似,如果firstNamelastNamenull,上述代码工作起来就会抛NullPointerException。为了解决这个问题,我们在计算hashCode()的时候,经常借助Objects.hash()来计算:

1
2
3
int hashCode() {
return Objects.hash(firstName, lastName, age);
}

所以,编写equals()hashCode()遵循的原则是:

equals()用到的用于比较的每一个字段,都必须在hashCode()中用于计算;equals()中没有使用到的字段,绝不可放在hashCode()中计算。

另外注意,对于放入HashMapvalue对象,没有任何要求。

延伸阅读

既然HashMap内部使用了数组,通过计算keyhashCode()直接定位value所在的索引,那么第一个问题来了:hashCode()返回的int范围高达±21亿,先不考虑负数,HashMap内部使用的数组得有多大?

实际上HashMap初始化时默认的数组大小只有16,任何key,无论它的hashCode()有多大,都可以简单地通过:

1
int index = key.hashCode() & 0xf; // 0xf = 15

把索引确定在0 ~ 15,即永远不会超出数组范围,上述算法只是一种最简单的实现。

第二个问题:如果添加超过16个key-valueHashMap,数组不够用了怎么办?

添加超过一定数量的key-value时,HashMap会在内部自动扩容,每次扩容一倍,即长度为16的数组扩展为长度32,相应地,需要重新确定hashCode()计算的索引位置。例如,对长度为32的数组计算hashCode()对应的索引,计算方式要改为:

1
int index = key.hashCode() & 0x1f; // 0x1f = 31

由于扩容会导致重新分布已有的key-value,所以,频繁扩容对HashMap的性能影响很大。如果我们确定要使用一个容量为10000key-valueHashMap,更好的方式是创建HashMap时就指定容量:

1
Map<String, Integer> map = new HashMap<>(10000);

虽然指定容量是10000,但HashMap内部的数组长度总是2n,因此,实际数组长度被初始化为比10000大的16384(214)。

最后一个问题:如果不同的两个key,例如"a""b",它们的hashCode()恰好是相同的(这种情况是完全可能的,因为不相等的两个实例,只要求hashCode()尽量不相等),那么,当我们放入:

1
2
map.put("a", new Person("Xiao Ming"));
map.put("b", new Person("Xiao Hong"));

时,由于计算出的数组索引相同,后面放入的"Xiao Hong"会不会把"Xiao Ming"覆盖了?

当然不会!使用Map的时候,只要key不相同,它们映射的value就互不干扰。但是,在HashMap内部,确实可能存在不同的key,映射到相同的hashCode(),即相同的数组索引上,肿么办?

我们就假设"a""b"这两个key最终计算出的索引都是5,那么,在HashMap的数组中,实际存储的不是一个Person实例,而是一个List,它包含两个Entry,一个是"a"的映射,一个是"b"的映射:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
  ┌───┐
0 │ │
├───┤
1 │ │
├───┤
2 │ │
├───┤
3 │ │
├───┤
4 │ │
├───┤
5 │ ●─┼───▶ List<Entry<String, Person>>
├───┤
6 │ │
├───┤
7 │ │
└───┘

在查找的时候,例如:

1
Person p = map.get("a");

HashMap内部通过"a"找到的实际上是List<Entry<String, Person>>,它还需要遍历这个List,并找到一个Entry,它的key字段是"a",才能返回对应的Person实例。

我们把不同的key具有相同的hashCode()的情况称之为哈希冲突。在冲突的时候,一种最简单的解决办法是用List存储hashCode()相同的key-value。显然,如果冲突的概率越大,这个List就越长,Mapget()方法效率就越低,这就是为什么要尽量满足条件二:

提示

如果两个对象不相等,则两个对象的hashCode()尽量不要相等。

hashCode()方法编写得越好,HashMap工作的效率就越高。

小结

要正确使用HashMap,作为key的类必须正确覆写equals()hashCode()方法;

一个类如果覆写了equals(),就必须覆写hashCode(),并且覆写规则是:

  • 如果equals()返回true,则hashCode()返回值必须相等;
  • 如果equals()返回false,则hashCode()返回值尽量不要相等。

实现hashCode()方法可以通过Objects.hashCode()辅助方法实现。

使用EnumMap

因为HashMap是一种通过对key计算hashCode(),通过空间换时间的方式,直接定位到value所在的内部数组的索引,因此,查找效率非常高。

如果作为key的对象是enum类型,那么,还可以使用Java集合库提供的一种EnumMap,它在内部以一个非常紧凑的数组存储value,并且根据enum类型的key直接定位到内部数组的索引,并不需要计算hashCode(),不但效率最高,而且没有额外的空间浪费。

我们以DayOfWeek这个枚举类型为例,为它做一个“翻译”功能:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import java.time.DayOfWeek;
import java.util.*;

public class Main {
public static void main(String[] args) {
Map<DayOfWeek, String> map = new EnumMap<>(DayOfWeek.class);
map.put(DayOfWeek.MONDAY, "星期一");
map.put(DayOfWeek.TUESDAY, "星期二");
map.put(DayOfWeek.WEDNESDAY, "星期三");
map.put(DayOfWeek.THURSDAY, "星期四");
map.put(DayOfWeek.FRIDAY, "星期五");
map.put(DayOfWeek.SATURDAY, "星期六");
map.put(DayOfWeek.SUNDAY, "星期日");
System.out.println(map);
System.out.println(map.get(DayOfWeek.MONDAY));
}
}

使用EnumMap的时候,我们总是用Map接口来引用它,因此,实际上把HashMapEnumMap互换,在客户端看来没有任何区别。

小结

如果Map的key是enum类型,推荐使用EnumMap,既保证速度,也不浪费空间。

使用EnumMap的时候,根据面向抽象编程的原则,应持有Map接口。



我们已经知道,HashMap是一种以空间换时间的映射表,它的实现原理决定了内部的Key是无序的,即遍历HashMap的Key时,其顺序是不可预测的(但每个Key都会遍历一次且仅遍历一次)。

还有一种Map,它在内部会对Key进行排序,这种Map就是SortedMap。注意到SortedMap是接口,它的实现类是TreeMap

1
2
3
4
5
6
7
8
9
10
11
12
13
14
       ┌───┐
│Map│
└───┘

┌────┴─────┐
│ │
┌───────┐ ┌─────────┐
│HashMap│ │SortedMap│
└───────┘ └─────────┘


┌─────────┐
│ TreeMap │
└─────────┘

SortedMap保证遍历时以Key的顺序来进行排序。例如,放入的Key是"apple""pear""orange",遍历的顺序一定是"apple""orange""pear",因为String默认按字母排序:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import java.util.*;

public class Main {
public static void main(String[] args) {
Map<String, Integer> map = new TreeMap<>();
map.put("orange", 1);
map.put("apple", 2);
map.put("pear", 3);
for (String key : map.keySet()) {
System.out.println(key);
}
// apple, orange, pear
}
}

使用TreeMap时,放入的Key必须实现Comparable接口。StringInteger这些类已经实现了Comparable接口,因此可以直接作为Key使用。作为Value的对象则没有任何要求。

如果作为Key的class没有实现Comparable接口,那么,必须在创建TreeMap时同时指定一个自定义排序算法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import java.util.*;

public class Main {
public static void main(String[] args) {
Map<Person, Integer> map = new TreeMap<>(new Comparator<Person>() {
public int compare(Person p1, Person p2) {
return p1.name.compareTo(p2.name);
}
});
map.put(new Person("Tom"), 1);
map.put(new Person("Bob"), 2);
map.put(new Person("Lily"), 3);
for (Person key : map.keySet()) {
System.out.println(key);
}
// {Person: Bob}, {Person: Lily}, {Person: Tom}
System.out.println(map.get(new Person("Bob"))); // 2
}
}

class Person {
public String name;
Person(String name) {
this.name = name;
}
public String toString() {
return "{Person: " + name + "}";
}
}

注意到Comparator接口要求实现一个比较方法,它负责比较传入的两个元素ab,如果a<b,则返回负数,通常是-1,如果a==b,则返回0,如果a>b,则返回正数,通常是1TreeMap内部根据比较结果对Key进行排序。

从上述代码执行结果可知,打印的Key确实是按照Comparator定义的顺序排序的。如果要根据Key查找Value,我们可以传入一个new Person("Bob")作为Key,它会返回对应的Integer2

另外,注意到Person类并未覆写equals()hashCode(),因为TreeMap不使用equals()hashCode()

我们来看一个稍微复杂的例子:这次我们定义了Student类,并用分数score进行排序,高分在前:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import java.util.*;

public class Main {
public static void main(String[] args) {
Map<Student, Integer> map = new TreeMap<>(new Comparator<Student>() {
public int compare(Student p1, Student p2) {
return p1.score > p2.score ? -1 : 1;
}
});
map.put(new Student("Tom", 77), 1);
map.put(new Student("Bob", 66), 2);
map.put(new Student("Lily", 99), 3);
for (Student key : map.keySet()) {
System.out.println(key);
}
System.out.println(map.get(new Student("Bob", 66))); // null?
}
}

class Student {
public String name;
public int score;
Student(String name, int score) {
this.name = name;
this.score = score;
}
public String toString() {
return String.format("{%s: score=%d}", name, score);
}
}

for循环中,我们确实得到了正确的顺序。但是,且慢!根据相同的Key:new Student("Bob", 66)进行查找时,结果为null

这是怎么肥四?难道TreeMap有问题?遇到TreeMap工作不正常时,我们首先回顾Java编程基本规则:出现问题,不要怀疑Java标准库,要从自身代码找原因。

在这个例子中,TreeMap出现问题,原因其实出在这个Comparator上:

1
2
3
public int compare(Student p1, Student p2) {
return p1.score > p2.score ? -1 : 1;
}

p1.scorep2.score不相等的时候,它的返回值是正确的,但是,在p1.scorep2.score相等的时候,它并没有返回0!这就是为什么TreeMap工作不正常的原因:TreeMap在比较两个Key是否相等时,依赖Key的compareTo()方法或者Comparator.compare()方法。在两个Key相等时,必须返回0。因此,修改代码如下:

1
2
3
4
5
6
public int compare(Student p1, Student p2) {
if (p1.score == p2.score) {
return 0;
}
return p1.score > p2.score ? -1 : 1;
}

或者直接借助Integer.compare(int, int)也可以返回正确的比较结果。

注意

使用TreeMap时,对Key的比较需要正确实现相等、大于和小于逻辑!

小结

SortedMap在遍历时严格按照Key的顺序遍历,最常用的实现类是TreeMap

作为SortedMap的Key必须实现Comparable接口,或者传入Comparator

要严格按照compare()规范实现比较逻辑,否则,TreeMap将不能正常工作。

在编写应用程序的时候,经常需要读写配置文件。例如,用户的设置:

1
2
3
4
# 上次最后打开的文件:
last_open_file=/data/hello.txt
# 自动保存文件的时间间隔:
auto_save_interval=60

配置文件的特点是,它的Key-Value一般都是String-String类型的,因此我们完全可以用Map<String, String>来表示它。

因为配置文件非常常用,所以Java集合库提供了一个Properties来表示一组“配置”。由于历史遗留原因,Properties内部本质上是一个Hashtable,但我们只需要用到Properties自身关于读写配置的接口。

读取配置文件

Properties读取配置文件非常简单。Java默认配置文件以.properties为扩展名,每行以key=value表示,以#课开头的是注释。以下是一个典型的配置文件:

1
2
3
4
# setting.properties

last_open_file=/data/hello.txt
auto_save_interval=60

可以从文件系统读取这个.properties文件:

1
2
3
4
5
6
String f = "setting.properties";
Properties props = new Properties();
props.load(new java.io.FileInputStream(f));

String filepath = props.getProperty("last_open_file");
String interval = props.getProperty("auto_save_interval", "120");

可见,用Properties读取配置文件,一共有三步:

  1. 创建Properties实例;
  2. 调用load()读取文件;
  3. 调用getProperty()获取配置。

调用getProperty()获取配置时,如果key不存在,将返回null。我们还可以提供一个默认值,这样,当key不存在的时候,就返回默认值。

也可以从classpath读取.properties文件,因为load(InputStream)方法接收一个InputStream实例,表示一个字节流,它不一定是文件流,也可以是从jar包中读取的资源流:

1
2
Properties props = new Properties();
props.load(getClass().getResourceAsStream("/common/setting.properties"));

试试从内存读取一个字节流:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
// properties
import java.io.*;
import java.util.Properties;

public class Main {
public static void main(String[] args) throws IOException {
String settings = "# test" + "\n" + "course=Java" + "\n" + "last_open_date=2019-08-07T12:35:01";
ByteArrayInputStream input = new ByteArrayInputStream(settings.getBytes("UTF-8"));
Properties props = new Properties();
props.load(input);

System.out.println("course: " + props.getProperty("course"));
System.out.println("last_open_date: " + props.getProperty("last_open_date"));
System.out.println("last_open_file: " + props.getProperty("last_open_file"));
System.out.println("auto_save: " + props.getProperty("auto_save", "60"));
}
}

如果有多个.properties文件,可以反复调用load()读取,后读取的key-value会覆盖已读取的key-value:

1
2
3
Properties props = new Properties();
props.load(getClass().getResourceAsStream("/common/setting.properties"));
props.load(new FileInputStream("C:\\conf\\setting.properties"));

上面的代码演示了Properties的一个常用用法:可以把默认配置文件放到classpath中,然后,根据机器的环境编写另一个配置文件,覆盖某些默认的配置。

Properties设计的目的是存储String类型的key-value,但Properties实际上是从Hashtable派生的,它的设计实际上是有问题的,但是为了保持兼容性,现在已经没法修改了。除了getProperty()setProperty()方法外,还有从Hashtable继承下来的get()put()方法,这些方法的参数签名是Object,我们在使用Properties的时候,不要去调用这些从Hashtable继承下来的方法。

写入配置文件

如果通过setProperty()修改了Properties实例,可以把配置写入文件,以便下次启动时获得最新配置。写入配置文件使用store()方法:

1
2
3
4
Properties props = new Properties();
props.setProperty("url", "http://www.liaoxuefeng.com");
props.setProperty("language", "Java");
props.store(new FileOutputStream("C:\\conf\\setting.properties"), "这是写入的properties注释");

编码

早期版本的Java规定.properties文件编码是ASCII编码(ISO8859-1),如果涉及到中文就必须用name=\u4e2d\u6587来表示,非常别扭。从JDK9开始,Java的.properties文件可以使用UTF-8编码了。

不过,需要注意的是,由于load(InputStream)默认总是以ASCII编码读取字节流,所以会导致读到乱码。我们需要用另一个重载方法load(Reader)读取:

1
2
Properties props = new Properties();
props.load(new FileReader("settings.properties", StandardCharsets.UTF_8));

就可以正常读取中文。InputStreamReader的区别是一个是字节流,一个是字符流。字符流在内存中已经以char类型表示了,不涉及编码问题。

小结

Java集合库提供的Properties用于读写配置文件.properties.properties文件可以使用UTF-8编码;

可以从文件系统、classpath或其他任何地方读取.properties文件;

读写Properties时,注意仅使用getProperty()setProperty()方法,不要调用继承而来的get()put()等方法。

我们知道,Map用于存储key-value的映射,对于充当key的对象,是不能重复的,并且,不但需要正确覆写equals()方法,还要正确覆写hashCode()方法。

如果我们只需要存储不重复的key,并不需要存储映射的value,那么就可以使用Set

Set用于存储不重复的元素集合,它主要提供以下几个方法:

  • 将元素添加进Set<E>boolean add(E e)
  • 将元素从Set<E>删除:boolean remove(Object e)
  • 判断是否包含元素:boolean contains(Object e)

我们来看几个简单的例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import java.util.*;

public class Main {
public static void main(String[] args) {
Set<String> set = new HashSet<>();
System.out.println(set.add("abc")); // true
System.out.println(set.add("xyz")); // true
System.out.println(set.add("xyz")); // false,添加失败,因为元素已存在
System.out.println(set.contains("xyz")); // true,元素存在
System.out.println(set.contains("XYZ")); // false,元素不存在
System.out.println(set.remove("hello")); // false,删除失败,因为元素不存在
System.out.println(set.size()); // 2,一共两个元素
}
}

Set实际上相当于只存储key、不存储value的Map。我们经常用Set用于去除重复元素。

因为放入Set的元素和Map的key类似,都要正确实现equals()hashCode()方法,否则该元素无法正确地放入Set

最常用的Set实现类是HashSet,实际上,HashSet仅仅是对HashMap的一个简单封装,它的核心代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
public class HashSet<E> implements Set<E> {
// 持有一个HashMap:
private HashMap<E, Object> map = new HashMap<>();

// 放入HashMap的value:
private static final Object PRESENT = new Object();

public boolean add(E e) {
return map.put(e, PRESENT) == null;
}

public boolean contains(Object o) {
return map.containsKey(o);
}

public boolean remove(Object o) {
return map.remove(o) == PRESENT;
}
}

Set接口并不保证有序,而SortedSet接口则保证元素是有序的:

  • HashSet是无序的,因为它实现了Set接口,并没有实现SortedSet接口;
  • TreeSet是有序的,因为它实现了SortedSet接口。

用一张图表示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
       ┌───┐
│Set│
└───┘

┌────┴─────┐
│ │
┌───────┐ ┌─────────┐
│HashSet│ │SortedSet│
└───────┘ └─────────┘


┌─────────┐
│ TreeSet │
└─────────┘

我们来看HashSet的输出:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import java.util.*;

public class Main {
public static void main(String[] args) {
Set<String> set = new HashSet<>();
set.add("apple");
set.add("banana");
set.add("pear");
set.add("orange");
for (String s : set) {
System.out.println(s);
}
}
}

注意输出的顺序既不是添加的顺序,也不是String排序的顺序,在不同版本的JDK中,这个顺序也可能是不同的。

HashSet换成TreeSet,在遍历TreeSet时,输出就是有序的,这个顺序是元素的排序顺序:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import java.util.*;

public class Main {
public static void main(String[] args) {
Set<String> set = new TreeSet<>();
set.add("apple");
set.add("banana");
set.add("pear");
set.add("orange");
for (String s : set) {
System.out.println(s);
}
}
}

使用TreeSet和使用TreeMap的要求一样,添加的元素必须正确实现Comparable接口,如果没有实现Comparable接口,那么创建TreeSet时必须传入一个Comparator对象。

练习

在聊天软件中,发送方发送消息时,遇到网络超时后就会自动重发,因此,接收方可能会收到重复的消息,在显示给用户看的时候,需要首先去重。请练习使用Set去除重复的消息:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import java.util.*;

public class Main {
public static void main(String[] args) {
List<Message> received = List.of(
new Message(1, "Hello!"),
new Message(2, "发工资了吗?"),
new Message(2, "发工资了吗?"),
new Message(3, "去哪吃饭?"),
new Message(3, "去哪吃饭?"),
new Message(4, "Bye")
);
List<Message> displayMessages = process(received);
for (Message message : displayMessages) {
System.out.println(message.text);
}
}

static List<Message> process(List<Message> received) {
// TODO: 按sequence去除重复消息
return received;
}
}

class Message {
public final int sequence;
public final String text;
public Message(int sequence, String text) {
this.sequence = sequence;
this.text = text;
}
}

小结

Set用于存储不重复的元素集合:

  • 放入HashSet的元素与作为HashMap的key要求相同;
  • 放入TreeSet的元素与作为TreeMap的Key要求相同。

利用Set可以去除重复元素;

遍历SortedSet按照元素的排序顺序遍历,也可以自定义排序算法。

队列(Queue)是一种经常使用的集合。Queue实际上是实现了一个先进先出(FIFO:First In First Out)的有序表。它和List的区别在于,List可以在任意位置添加和删除元素,而Queue只有两个操作:

  • 把元素添加到队列末尾;
  • 从队列头部取出元素。

超市的收银台就是一个队列:

queue

在Java的标准库中,队列接口Queue定义了以下几个方法:

  • int size():获取队列长度;
  • boolean add(E)/boolean offer(E):添加元素到队尾;
  • E remove()/E poll():获取队首元素并从队列中删除;
  • E element()/E peek():获取队首元素但并不从队列中删除。

对于具体的实现类,有的Queue有最大队列长度限制,有的Queue没有。注意到添加、删除和获取队列元素总是有两个方法,这是因为在添加或获取元素失败时,这两个方法的行为是不同的。我们用一个表格总结如下:

throw Exception 返回false或null
添加元素到队尾 add(E e) boolean offer(E e)
取队首元素并删除 E remove() E poll()
取队首元素但不删除 E element() E peek()

举个栗子,假设我们有一个队列,对它做一个添加操作,如果调用add()方法,当添加失败时(可能超过了队列的容量),它会抛出异常:

1
2
3
4
5
6
7
Queue<String> q = ...
try {
q.add("Apple");
System.out.println("添加成功");
} catch(IllegalStateException e) {
System.out.println("添加失败");
}

如果我们调用offer()方法来添加元素,当添加失败时,它不会抛异常,而是返回false

1
2
3
4
5
6
Queue<String> q = ...
if (q.offer("Apple")) {
System.out.println("添加成功");
} else {
System.out.println("添加失败");
}

当我们需要从Queue中取出队首元素时,如果当前Queue是一个空队列,调用remove()方法,它会抛出异常:

1
2
3
4
5
6
7
Queue<String> q = ...
try {
String s = q.remove();
System.out.println("获取成功");
} catch(IllegalStateException e) {
System.out.println("获取失败");
}

如果我们调用poll()方法来取出队首元素,当获取失败时,它不会抛异常,而是返回null

1
2
3
4
5
6
7
Queue<String> q = ...
String s = q.poll();
if (s != null) {
System.out.println("获取成功");
} else {
System.out.println("获取失败");
}

因此,两套方法可以根据需要来选择使用。

注意:不要把null添加到队列中,否则poll()方法返回null时,很难确定是取到了null元素还是队列为空。

接下来我们以poll()peek()为例来说说“获取并删除”与“获取但不删除”的区别。对于Queue来说,每次调用poll(),都会获取队首元素,并且获取到的元素已经从队列中被删除了:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import java.util.LinkedList;
import java.util.Queue;

public class Main {
public static void main(String[] args) {
Queue<String> q = new LinkedList<>();
// 添加3个元素到队列:
q.offer("apple");
q.offer("pear");
q.offer("banana");
// 从队列取出元素:
System.out.println(q.poll()); // apple
System.out.println(q.poll()); // pear
System.out.println(q.poll()); // banana
System.out.println(q.poll()); // null,因为队列是空的
}
}

如果用peek(),因为获取队首元素时,并不会从队列中删除这个元素,所以可以反复获取:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import java.util.LinkedList;
import java.util.Queue;

public class Main {
public static void main(String[] args) {
Queue<String> q = new LinkedList<>();
// 添加3个元素到队列:
q.offer("apple");
q.offer("pear");
q.offer("banana");
// 队首永远都是apple,因为peek()不会删除它:
System.out.println(q.peek()); // apple
System.out.println(q.peek()); // apple
System.out.println(q.peek()); // apple
}
}

从上面的代码中,我们还可以发现,LinkedList即实现了List接口,又实现了Queue接口,但是,在使用的时候,如果我们把它当作List,就获取List的引用,如果我们把它当作Queue,就获取Queue的引用:

1
2
3
4
// 这是一个List:
List<String> list = new LinkedList<>();
// 这是一个Queue:
Queue<String> queue = new LinkedList<>();

始终按照面向抽象编程的原则编写代码,可以大大提高代码的质量。

小结

队列Queue实现了一个先进先出(FIFO)的数据结构:

  • 通过add()/offer()方法将元素添加到队尾;
  • 通过remove()/poll()从队首获取元素并删除;
  • 通过element()/peek()从队首获取元素但不删除。

要避免把null添加到队列。

我们知道,Queue是一个先进先出(FIFO)的队列。

在银行柜台办业务时,我们假设只有一个柜台在办理业务,但是办理业务的人很多,怎么办?

可以每个人先取一个号,例如:A1A2A3……然后,按照号码顺序依次办理,实际上这就是一个Queue

如果这时来了一个VIP客户,他的号码是V1,虽然当前排队的是A10A11A12……但是柜台下一个呼叫的客户号码却是V1

这个时候,我们发现,要实现“VIP插队”的业务,用Queue就不行了,因为Queue会严格按FIFO的原则取出队首元素。我们需要的是优先队列:PriorityQueue

PriorityQueueQueue的区别在于,它的出队顺序与元素的优先级有关,对PriorityQueue调用remove()poll()方法,返回的总是优先级最高的元素。

要使用PriorityQueue,我们就必须给每个元素定义“优先级”。我们以实际代码为例,先看看PriorityQueue的行为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import java.util.PriorityQueue;
import java.util.Queue;

public class Main {
public static void main(String[] args) {
Queue<String> q = new PriorityQueue<>();
// 添加3个元素到队列:
q.offer("apple");
q.offer("pear");
q.offer("banana");
System.out.println(q.poll()); // apple
System.out.println(q.poll()); // banana
System.out.println(q.poll()); // pear
System.out.println(q.poll()); // null,因为队列为空
}
}

我们放入的顺序是"apple""pear""banana",但是取出的顺序却是"apple""banana""pear",这是因为从字符串的排序看,"apple"排在最前面,"pear"排在最后面。

因此,放入PriorityQueue的元素,必须实现Comparable接口,PriorityQueue会根据元素的排序顺序决定出队的优先级。

如果我们要放入的元素并没有实现Comparable接口怎么办?PriorityQueue允许我们提供一个Comparator对象来判断两个元素的顺序。我们以银行排队业务为例,实现一个PriorityQueue

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import java.util.Comparator;
import java.util.PriorityQueue;
import java.util.Queue;

public class Main {
public static void main(String[] args) {
Queue<User> q = new PriorityQueue<>(new UserComparator());
// 添加3个元素到队列:
q.offer(new User("Bob", "A1"));
q.offer(new User("Alice", "A2"));
q.offer(new User("Boss", "V1"));
System.out.println(q.poll()); // Boss/V1
System.out.println(q.poll()); // Bob/A1
System.out.println(q.poll()); // Alice/A2
System.out.println(q.poll()); // null,因为队列为空
}
}

class UserComparator implements Comparator<User> {
public int compare(User u1, User u2) {
if (u1.number.charAt(0) == u2.number.charAt(0)) {
// 如果两人的号都是A开头或者都是V开头,比较号的大小:
return u1.number.compareTo(u2.number);
}
if (u1.number.charAt(0) == 'V') {
// u1的号码是V开头,优先级高:
return -1;
} else {
return 1;
}
}
}

class User {
public final String name;
public final String number;

public User(String name, String number) {
this.name = name;
this.number = number;
}

public String toString() {
return name + "/" + number;
}
}

实现PriorityQueue的关键在于提供的UserComparator对象,它负责比较两个元素的大小(较小的在前)。UserComparator总是把V开头的号码优先返回,只有在开头相同的时候,才比较号码大小。

上面的UserComparator的比较逻辑其实还是有问题的,它会把A10排在A2的前面,请尝试修复该错误。

小结

PriorityQueue实现了一个优先队列:从队首获取元素时,总是获取优先级最高的元素;

PriorityQueue默认按元素比较的顺序排序(必须实现Comparable接口),也可以通过Comparator自定义排序算法(元素就不必实现Comparable接口)。

使用Deque

我们知道,Queue是队列,只能一头进,另一头出。

如果把条件放松一下,允许两头都进,两头都出,这种队列叫双端队列(Double Ended Queue),学名Deque

Java集合提供了接口Deque来实现一个双端队列,它的功能是:

  • 既可以添加到队尾,也可以添加到队首;
  • 既可以从队首获取,又可以从队尾获取。

我们来比较一下QueueDeque出队和入队的方法:

Queue Deque
添加元素到队尾 add(E e) / offer(E e) addLast(E e) / offerLast(E e)
取队首元素并删除 E remove() / E poll() E removeFirst() / E pollFirst()
取队首元素但不删除 E element() / E peek() E getFirst() / E peekFirst()
添加元素到队首 addFirst(E e) / offerFirst(E e)
取队尾元素并删除 E removeLast() / E pollLast()
取队尾元素但不删除 E getLast() / E peekLast()

对于添加元素到队尾的操作,Queue提供了add()/offer()方法,而Deque提供了addLast()/offerLast()方法。添加元素到队首、取队尾元素的操作在Queue中不存在,在Deque中由addFirst()/removeLast()等方法提供。

注意到Deque接口实际上扩展自Queue

1
2
3
public interface Deque<E> extends Queue<E> {
...
}

因此,Queue提供的add()/offer()方法在Deque中也可以使用,但是,使用Deque,最好不要调用offer(),而是调用offerLast()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import java.util.Deque;
import java.util.LinkedList;

public class Main {
public static void main(String[] args) {
Deque<String> deque = new LinkedList<>();
deque.offerLast("A"); // A
deque.offerLast("B"); // A <- B
deque.offerFirst("C"); // C <- A <- B
System.out.println(deque.pollFirst()); // C, 剩下A <- B
System.out.println(deque.pollLast()); // B, 剩下A
System.out.println(deque.pollFirst()); // A
System.out.println(deque.pollFirst()); // null
}
}

如果直接写deque.offer(),我们就需要思考,offer()实际上是offerLast(),我们明确地写上offerLast(),不需要思考就能一眼看出这是添加到队尾。

因此,使用Deque,推荐总是明确调用offerLast()/offerFirst()或者pollFirst()/pollLast()方法。

Deque是一个接口,它的实现类有ArrayDequeLinkedList

我们发现LinkedList真是一个全能选手,它即是List,又是Queue,还是Deque。但是我们在使用的时候,总是用特定的接口来引用它,这是因为持有接口说明代码的抽象层次更高,而且接口本身定义的方法代表了特定的用途。

1
2
3
4
5
6
// 不推荐的写法:
LinkedList<String> d1 = new LinkedList<>();
d1.offerLast("z");
// 推荐的写法:
Deque<String> d2 = new LinkedList<>();
d2.offerLast("z");

可见面向抽象编程的一个原则就是:尽量持有接口,而不是具体的实现类。

小结

Deque实现了一个双端队列(Double Ended Queue),它可以:

  • 将元素添加到队尾或队首:addLast()/offerLast()/addFirst()/offerFirst()
  • 从队首/队尾获取元素并删除:removeFirst()/pollFirst()/removeLast()/pollLast()
  • 从队首/队尾获取元素但不删除:getFirst()/peekFirst()/getLast()/peekLast()
  • 总是调用xxxFirst()/xxxLast()以便与Queue的方法区分开;
  • 避免把null添加到队列。


栈(Stack)是一种后进先出(LIFO:Last In First Out)的数据结构。

什么是LIFO呢?我们先回顾一下Queue的特点FIFO:

1
2
3
4
5
           ────────────────────────
(\(\ (\(\ (\(\ (\(\ (\(\
(='.') ──▶ (='.') (='.') (='.') ──▶ (='.')
O(_")") O(_")") O(_")") O(_")") O(_")")
────────────────────────

所谓FIFO,是最先进队列的元素一定最早出队列,而LIFO是最后进Stack的元素一定最早出Stack。如何做到这一点呢?只需要把队列的一端封死:

1
2
3
4
5
            ───────────────────────────────┐
(\(\ (\(\ (\(\ (\(\ (\(\ │
(='.') ◀──▶ (='.') (='.') (='.') (='.')│
O(_")") O(_")") O(_")") O(_")") O(_")")│
───────────────────────────────┘

因此,Stack是这样一种数据结构:只能不断地往Stack中压入(push)元素,最后进去的必须最早弹出(pop)来:

donuts-stack

Stack只有入栈和出栈的操作:

  • 把元素压栈:push(E)
  • 把栈顶的元素“弹出”:pop()
  • 取栈顶元素但不弹出:peek()

在Java中,我们用Deque可以实现Stack的功能:

  • 把元素压栈:push(E)/addFirst(E)
  • 把栈顶的元素“弹出”:pop()/removeFirst()
  • 取栈顶元素但不弹出:peek()/peekFirst()

为什么Java的集合类没有单独的Stack接口呢?因为有个遗留类名字就叫Stack,出于兼容性考虑,所以没办法创建Stack接口,只能用Deque接口来“模拟”一个Stack了。

当我们把Deque作为Stack使用时,注意只调用push()/pop()/peek()方法,不要调用addFirst()/removeFirst()/peekFirst()方法,这样代码更加清晰。

Stack的作用

Stack在计算机中使用非常广泛,JVM在处理Java方法调用的时候就会通过栈这种数据结构维护方法调用的层次。例如:

1
2
3
4
5
6
7
8
9
10
11
static void main(String[] args) {
foo(123);
}

static String foo(x) {
return "F-" + bar(x + 1);
}

static int bar(int x) {
return x << 2;
}

JVM会创建方法调用栈,每调用一个方法时,先将参数压栈,然后执行对应的方法;当方法返回时,返回值压栈,调用方法通过出栈操作获得方法返回值。

因为方法调用栈有容量限制,嵌套调用过多会造成栈溢出,即引发StackOverflowError

1
2
3
4
5
6
7
8
9
10
// 测试无限递归调用
public class Main {
public static void main(String[] args) {
increase(1);
}

static int increase(int x) {
return increase(x) + 1;
}
}

我们再来看一个Stack的用途:对整数进行进制的转换就可以利用栈。

例如,我们要把一个int整数12500转换为十六进制表示的字符串,如何实现这个功能?

首先我们准备一个空栈:

1
2
3
4
5
6
│   │
│ │
│ │
│ │
│ │
└───┘

然后计算12500÷16=781…4,余数是4,把余数4压栈:

1
2
3
4
5
6
│   │
│ │
│ │
│ │
│ 4 │
└───┘

然后计算781÷16=48…13,余数是1313的十六进制用字母D表示,把余数D压栈:

1
2
3
4
5
6
│   │
│ │
│ │
│ D │
│ 4 │
└───┘

然后计算48÷16=3…0,余数是0,把余数0压栈:

1
2
3
4
5
6
│   │
│ │
│ 0 │
│ D │
│ 4 │
└───┘

最后计算3÷16=0…3,余数是3,把余数3压栈:

1
2
3
4
5
6
│   │
│ 3 │
│ 0 │
│ D │
│ 4 │
└───┘

当商是0的时候,计算结束,我们把栈的所有元素依次弹出,组成字符串30D4,这就是十进制整数12500的十六进制表示的字符串。

计算中缀表达式

在编写程序的时候,我们使用的带括号的数学表达式实际上是中缀表达式,即运算符在中间,例如:1 + 2 * (9 - 5)

但是计算机执行表达式的时候,它并不能直接计算中缀表达式,而是通过编译器把中缀表达式转换为后缀表达式,例如:1 2 9 5 - * +

这个编译过程就会用到栈。我们先跳过编译这一步(涉及运算优先级,代码比较复杂),看看如何通过栈计算后缀表达式。

计算后缀表达式不考虑优先级,直接从左到右依次计算,因此计算起来简单。首先准备一个空的栈:

1
2
3
4
5
6
│   │
│ │
│ │
│ │
│ │
└───┘

然后我们依次扫描后缀表达式1 2 9 5 - * +,遇到数字1,就直接扔到栈里:

1
2
3
4
5
6
│   │
│ │
│ │
│ │
│ 1 │
└───┘

紧接着,遇到数字295,也扔到栈里:

1
2
3
4
5
6
│   │
│ 5 │
│ 9 │
│ 2 │
│ 1 │
└───┘

接下来遇到减号时,弹出栈顶的两个元素,并计算9-5=4,把结果4压栈:

1
2
3
4
5
6
│   │
│ │
│ 4 │
│ 2 │
│ 1 │
└───┘

接下来遇到*号时,弹出栈顶的两个元素,并计算2*4=8,把结果8压栈:

1
2
3
4
5
6
│   │
│ │
│ │
│ 8 │
│ 1 │
└───┘

接下来遇到+号时,弹出栈顶的两个元素,并计算1+8=9,把结果9压栈:

1
2
3
4
5
6
│   │
│ │
│ │
│ │
│ 9 │
└───┘

扫描结束后,没有更多的计算了,弹出栈的唯一一个元素,得到计算结果9

练习

请利用Stack把一个给定的整数转换为十六进制:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
// 转十六进制
import java.util.*;

public class Main {
public static void main(String[] args) {
String hex = toHex(12500);
if (hex.equalsIgnoreCase("30D4")) {
System.out.println("测试通过");
} else {
System.out.println("测试失败");
}
}

static String toHex(int n) {
return "";
}
}

进阶练习:

请利用Stack把字符串中缀表达式编译为后缀表达式,然后再利用栈执行后缀表达式获得计算结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
// 高难度练习,慎重选择!
import java.util.*;

public class Main {
public static void main(String[] args) {
String exp = "1 + 2 * (9 - 5)";
SuffixExpression se = compile(exp);
int result = se.execute();
System.out.println(exp + " = " + result + " " + (result == 1 + 2 * (9 - 5) ? "✓" : "✗"));
}

static SuffixExpression compile(String exp) {
// TODO:
return new SuffixExpression();
}
}

class SuffixExpression {
int execute() {
// TODO:
return 0;
}
}

进阶练习2:

请把带变量的中缀表达式编译为后缀表达式,执行后缀表达式时,传入变量的值并获得计算结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
// 超高难度练习,慎重选择!
import java.util.*;

public class Main {
public static void main(String[] args) {
String exp = "x + 2 * (y - 5)";
SuffixExpression se = compile(exp);
Map<String, Integer> env = Map.of("x", 1, "y", 9);
int result = se.execute(env);
System.out.println(exp + " = " + result + " " + (result == 1 + 2 * (9 - 5) ? "✓" : "✗"));
}

static SuffixExpression compile(String exp) {
// TODO:
return new SuffixExpression();
}
}

class SuffixExpression {
int execute(Map<String, Integer> env) {
// TODO:
return 0;
}
}

下载练习

小结

栈(Stack)是一种后进先出(LIFO)的数据结构,操作栈的元素的方法有:

  • 把元素压栈:push(E)
  • 把栈顶的元素“弹出”:pop(E)
  • 取栈顶元素但不弹出:peek(E)

在Java中,我们用Deque可以实现Stack的功能,注意只调用push()/pop()/peek()方法,避免调用Deque的其他方法;

不要使用遗留类Stack

Java的集合类都可以使用for each循环,ListSetQueue会迭代每个元素,Map会迭代每个key。以List为例:

1
2
3
4
List<String> list = List.of("Apple", "Orange", "Pear");
for (String s : list) {
System.out.println(s);
}

实际上,Java编译器并不知道如何遍历List。上述代码能够编译通过,只是因为编译器把for each循环通过Iterator改写为了普通的for循环:

1
2
3
4
for (Iterator<String> it = list.iterator(); it.hasNext(); ) {
String s = it.next();
System.out.println(s);
}

我们把这种通过Iterator对象遍历集合的模式称为迭代器。

使用迭代器的好处在于,调用方总是以统一的方式遍历各种集合类型,而不必关心它们内部的存储结构。

例如,我们虽然知道ArrayList在内部是以数组形式存储元素,并且,它还提供了get(int)方法。虽然我们可以用for循环遍历:

1
2
3
for (int i=0; i<list.size(); i++) {
Object value = list.get(i);
}

但是这样一来,调用方就必须知道集合的内部存储结构。并且,如果把ArrayList换成LinkedListget(int)方法耗时会随着index的增加而增加。如果把ArrayList换成Set,上述代码就无法编译,因为Set内部没有索引。

Iterator遍历就没有上述问题,因为Iterator对象是集合对象自己在内部创建的,它自己知道如何高效遍历内部的数据集合,调用方则获得了统一的代码,编译器才能把标准的for each循环自动转换为Iterator遍历。

如果我们自己编写了一个集合类,想要使用for each循环,只需满足以下条件:

  • 集合类实现Iterable接口,该接口要求返回一个Iterator对象;
  • Iterator对象迭代集合内部数据。

这里的关键在于,集合类通过调用iterator()方法,返回一个Iterator对象,这个对象必须自己知道如何遍历该集合。

一个简单的Iterator示例如下,它总是以倒序遍历集合:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
// Iterator
import java.util.*;

public class Main {
public static void main(String[] args) {
ReverseList<String> rlist = new ReverseList<>();
rlist.add("Apple");
rlist.add("Orange");
rlist.add("Pear");
for (String s : rlist) {
System.out.println(s);
}
}
}

class ReverseList<T> implements Iterable<T> {

private List<T> list = new ArrayList<>();

public void add(T t) {
list.add(t);
}

@Override
public Iterator<T> iterator() {
return new ReverseIterator(list.size());
}

class ReverseIterator implements Iterator<T> {
int index;

ReverseIterator(int index) {
this.index = index;
}

@Override
public boolean hasNext() {
return index > 0;
}

@Override
public T next() {
index--;
return ReverseList.this.list.get(index);
}
}
}

虽然ReverseListReverseIterator的实现类稍微比较复杂,但是,注意到这是底层集合库,只需编写一次。而调用方则完全按for each循环编写代码,根本不需要知道集合内部的存储逻辑和遍历逻辑。

在编写Iterator的时候,我们通常可以用一个内部类来实现Iterator接口,这个内部类可以直接访问对应的外部类的所有字段和方法。例如,上述代码中,内部类ReverseIterator可以用ReverseList.this获得当前外部类的this引用,然后,通过这个this引用就可以访问ReverseList的所有字段和方法。

小结

Iterator是一种抽象的数据访问模型。使用Iterator模式进行迭代的好处有:

  • 对任何集合都采用同一种访问模型;
  • 调用者对集合内部结构一无所知;
  • 集合类返回的Iterator对象知道如何迭代。

Java提供了标准的迭代器模型,即集合类实现java.util.Iterable接口,返回java.util.Iterator实例。

Collections是JDK提供的工具类,同样位于java.util包中。它提供了一系列静态方法,能更方便地操作各种集合。

注意

Collections结尾多了一个s,不是Collection!

我们一般看方法名和参数就可以确认Collections提供的该方法的功能。例如,对于以下静态方法:

1
public static boolean addAll(Collection<? super T> c, T... elements) { ... }

addAll()方法可以给一个Collection类型的集合添加若干元素。因为方法签名是Collection,所以我们可以传入ListSet等各种集合类型。

创建空集合

对于旧版的JDK,可以使用Collections提供的一系列方法来创建空集合:

  • 创建空List:List<T> emptyList()
  • 创建空Map:Map<K, V> emptyMap()
  • 创建空Set:Set<T> emptySet()

要注意到返回的空集合是不可变集合,无法向其中添加或删除元素。

新版的JDK≥9可以直接使用List.of()Map.of()Set.of()来创建空集合。

创建单元素集合

对于旧版的JDK,Collections提供了一系列方法来创建一个单元素集合:

  • 创建一个元素的List:List<T> singletonList(T o)
  • 创建一个元素的Map:Map<K, V> singletonMap(K key, V value)
  • 创建一个元素的Set:Set<T> singleton(T o)

要注意到返回的单元素集合也是不可变集合,无法向其中添加或删除元素。

新版的JDK≥9可以直接使用List.of(T...)Map.of(T...)Set.of(T...)来创建任意个元素的集合。

排序

Collections可以对List进行排序。因为排序会直接修改List元素的位置,因此必须传入可变List

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import java.util.*;

public class Main {
public static void main(String[] args) {
List<String> list = new ArrayList<>();
list.add("apple");
list.add("pear");
list.add("orange");
// 排序前:
System.out.println(list);
Collections.sort(list);
// 排序后:
System.out.println(list);
}
}

洗牌

Collections提供了洗牌算法,即传入一个有序的List,可以随机打乱List内部元素的顺序,效果相当于让计算机洗牌:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import java.util.*;

public class Main {
public static void main(String[] args) {
List<Integer> list = new ArrayList<>();
for (int i=0; i<10; i++) {
list.add(i);
}
// 洗牌前:
System.out.println(list);
// 洗牌:
Collections.shuffle(list);
// 洗牌后:
System.out.println(list);
}
}

不可变集合

Collections还提供了一组方法把可变集合封装成不可变集合:

  • 封装成不可变List:List<T> unmodifiableList(List<? extends T> list)
  • 封装成不可变Set:Set<T> unmodifiableSet(Set<? extends T> set)
  • 封装成不可变Map:Map<K, V> unmodifiableMap(Map<? extends K, ? extends V> m)

这种封装实际上是通过创建一个代理对象,拦截掉所有修改方法实现的。我们来看看效果:

1
2
3
4
5
6
7
8
9
10
11
12
import java.util.*;

public class Main {
public static void main(String[] args) {
List<String> mutable = new ArrayList<>();
mutable.add("apple");
mutable.add("pear");
// 变为不可变集合:
List<String> immutable = Collections.unmodifiableList(mutable);
immutable.add("orange"); // UnsupportedOperationException!
}
}

然而,继续对原始的可变List进行增删是可以的,并且,会直接影响到封装后的“不可变”List

1
2
3
4
5
6
7
8
9
10
11
12
13
import java.util.*;

public class Main {
public static void main(String[] args) {
List<String> mutable = new ArrayList<>();
mutable.add("apple");
mutable.add("pear");
// 变为不可变集合:
List<String> immutable = Collections.unmodifiableList(mutable);
mutable.add("orange");
System.out.println(immutable);
}
}

因此,如果我们希望把一个可变List封装成不可变List,那么,返回不可变List后,最好立刻扔掉可变List的引用,这样可以保证后续操作不会意外改变原始对象,从而造成“不可变”List变化了:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import java.util.*;

public class Main {
public static void main(String[] args) {
List<String> mutable = new ArrayList<>();
mutable.add("apple");
mutable.add("pear");
// 变为不可变集合:
List<String> immutable = Collections.unmodifiableList(mutable);
// 立刻扔掉mutable的引用:
mutable = null;
System.out.println(immutable);
}
}

线程安全集合

Collections还提供了一组方法,可以把线程不安全的集合变为线程安全的集合:

  • 变为线程安全的List:List<T> synchronizedList(List<T> list)
  • 变为线程安全的Set:Set<T> synchronizedSet(Set<T> s)
  • 变为线程安全的Map:Map<K,V> synchronizedMap(Map<K,V> m)

多线程的概念我们会在后面讲。因为从Java 5开始,引入了更高效的并发集合类,所以上述这几个同步方法已经没有什么用了。

小结

Collections类提供了一组工具方法来方便使用集合类:

  • 创建空集合;
  • 创建单元素集合;
  • 创建不可变集合;
  • 排序/洗牌等操作。

留言與分享

JAVA-泛型

分類 编程语言, Java

在讲解什么是泛型之前,我们先观察Java标准库提供的ArrayList,它可以看作“可变长度”的数组,因为用起来比数组更方便。

实际上ArrayList内部就是一个Object[]数组,配合存储一个当前分配的长度,就可以充当“可变数组”:

1
2
3
4
5
6
7
public class ArrayList {
private Object[] array;
private int size;
public void add(Object e) {...}
public void remove(int index) {...}
public Object get(int index) {...}
}

如果用上述ArrayList存储String类型,会有这么几个缺点:

  • 需要强制转型;
  • 不方便,易出错。

例如,代码必须这么写:

1
2
3
4
ArrayList list = new ArrayList();
list.add("Hello");
// 获取到Object,必须强制转型为String:
String first = (String) list.get(0);

很容易出现ClassCastException,因为容易“误转型”:

1
2
3
list.add(new Integer(123));
// ERROR: ClassCastException:
String second = (String) list.get(1);

要解决上述问题,我们可以为String单独编写一种ArrayList

1
2
3
4
5
6
7
public class StringArrayList {
private String[] array;
private int size;
public void add(String e) {...}
public void remove(int index) {...}
public String get(int index) {...}
}

这样一来,存入的必须是String,取出的也一定是String,不需要强制转型,因为编译器会强制检查放入的类型:

1
2
3
4
5
StringArrayList list = new StringArrayList();
list.add("Hello");
String first = list.get(0);
// 编译错误: 不允许放入非String类型:
list.add(new Integer(123));

问题暂时解决。

然而,新的问题是,如果要存储Integer,还需要为Integer单独编写一种ArrayList

1
2
3
4
5
6
7
public class IntegerArrayList {
private Integer[] array;
private int size;
public void add(Integer e) {...}
public void remove(int index) {...}
public Integer get(int index) {...}
}

实际上,还需要为其他所有class单独编写一种ArrayList

  • LongArrayList
  • DoubleArrayList
  • PersonArrayList

这是不可能的,JDK的class就有上千个,而且它还不知道其他人编写的class。

为了解决新的问题,我们必须把ArrayList变成一种模板:ArrayList<T>,代码如下:

1
2
3
4
5
6
7
public class ArrayList<T> {
private T[] array;
private int size;
public void add(T e) {...}
public void remove(int index) {...}
public T get(int index) {...}
}

T可以是任何class。这样一来,我们就实现了:编写一次模版,可以创建任意类型的ArrayList

1
2
3
4
5
6
// 创建可以存储String的ArrayList:
ArrayList<String> strList = new ArrayList<String>();
// 创建可以存储Float的ArrayList:
ArrayList<Float> floatList = new ArrayList<Float>();
// 创建可以存储Person的ArrayList:
ArrayList<Person> personList = new ArrayList<Person>();

因此,泛型就是定义一种模板,例如ArrayList<T>,然后在代码中为用到的类创建对应的ArrayList<类型>

1
ArrayList<String> strList = new ArrayList<String>();

由编译器针对类型作检查:

1
2
3
4
strList.add("hello"); // OK
String s = strList.get(0); // OK
strList.add(new Integer(123)); // compile error!
Integer n = strList.get(0); // compile error!

这样一来,既实现了编写一次,万能匹配,又通过编译器保证了类型安全:这就是泛型。

向上转型

在Java标准库中的ArrayList<T>实现了List<T>接口,它可以向上转型为List<T>

1
2
3
4
5
public class ArrayList<T> implements List<T> {
...
}

List<String> list = new ArrayList<String>();

即类型ArrayList<T>可以向上转型为List<T>

特别注意:不能把ArrayList<Integer>向上转型为ArrayList<Number>List<Number>

这是为什么呢?假设ArrayList<Integer>可以向上转型为ArrayList<Number>,观察一下代码:

1
2
3
4
5
6
7
8
9
10
// 创建ArrayList<Integer>类型:
ArrayList<Integer> integerList = new ArrayList<Integer>();
// 添加一个Integer:
integerList.add(new Integer(123));
// “向上转型”为ArrayList<Number>:
ArrayList<Number> numberList = integerList;
// 添加一个Float,因为Float也是Number:
numberList.add(new Float(12.34));
// 从ArrayList<Integer>获取索引为1的元素(即添加的Float):
Integer n = integerList.get(1); // ClassCastException!

我们把一个ArrayList<Integer>转型为ArrayList<Number>类型后,这个ArrayList<Number>就可以接受Float类型,因为FloatNumber的子类。但是,ArrayList<Number>实际上和ArrayList<Integer>是同一个对象,也就是ArrayList<Integer>类型,它不可能接受Float类型, 所以在获取Integer的时候将产生ClassCastException

实际上,编译器为了避免这种错误,根本就不允许把ArrayList<Integer>转型为ArrayList<Number>

特别注意

ArrayList<Integer>ArrayList<Number>两者完全没有继承关系。

用一个图来表示泛型的继承关系,就是T不变时,可以向上转型,T本身不能向上转型:

1
2
3
4
5
6
  List<Integer>     ArrayList<Number>
▲ ▲
│ │
│ X
│ │
ArrayList<Integer> ArrayList<Integer>

小结

泛型就是编写模板代码来适应任意类型;

泛型的好处是使用时不必对类型进行强制转换,它通过编译器对类型进行检查;

注意泛型的继承关系:可以把ArrayList<Integer>向上转型为List<Integer>T不能变!),但不能把ArrayList<Integer>向上转型为ArrayList<Number>T不能变成父类)。

使用ArrayList时,如果不定义泛型类型时,泛型类型实际上就是Object

1
2
3
4
5
6
// 编译器警告:
List list = new ArrayList();
list.add("Hello");
list.add("World");
String first = (String) list.get(0);
String second = (String) list.get(1);

此时,只能把<T>当作Object使用,没有发挥泛型的优势。

当我们定义泛型类型<String>后,List<T>的泛型接口变为强类型List<String>

1
2
3
4
5
6
7
// 无编译器警告:
List<String> list = new ArrayList<String>();
list.add("Hello");
list.add("World");
// 无强制转型:
String first = list.get(0);
String second = list.get(1);

当我们定义泛型类型<Number>后,List<T>的泛型接口变为强类型List<Number>

1
2
3
4
5
List<Number> list = new ArrayList<Number>();
list.add(new Integer(123));
list.add(new Double(12.34));
Number first = list.get(0);
Number second = list.get(1);

编译器如果能自动推断出泛型类型,就可以省略后面的泛型类型。例如,对于下面的代码:

1
List<Number> list = new ArrayList<Number>();

编译器看到泛型类型List<Number>就可以自动推断出后面的ArrayList<T>的泛型类型必须是ArrayList<Number>,因此,可以把代码简写为:

1
2
// 可以省略后面的Number,编译器可以自动推断泛型类型:
List<Number> list = new ArrayList<>();

泛型接口

除了ArrayList<T>使用了泛型,还可以在接口中使用泛型。例如,Arrays.sort(Object[])可以对任意数组进行排序,但待排序的元素必须实现Comparable<T>这个泛型接口:

1
2
3
4
5
6
7
8
public interface Comparable<T> {
/**
* 返回负数: 当前实例比参数o小
* 返回0: 当前实例与参数o相等
* 返回正数: 当前实例比参数o大
*/
int compareTo(T o);
}

可以直接对String数组进行排序:

1
2
3
4
5
6
7
8
9
10
// sort
import java.util.Arrays;

public class Main {
public static void main(String[] args) {
String[] ss = new String[] { "Orange", "Apple", "Pear" };
Arrays.sort(ss);
System.out.println(Arrays.toString(ss));
}
}

这是因为String本身已经实现了Comparable<String>接口。如果换成我们自定义的Person类型试试:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
// sort
import java.util.Arrays;

public class Main {
public static void main(String[] args) {
Person[] ps = new Person[] {
new Person("Bob", 61),
new Person("Alice", 88),
new Person("Lily", 75),
};
Arrays.sort(ps);
System.out.println(Arrays.toString(ps));
}
}

class Person {
String name;
int score;
Person(String name, int score) {
this.name = name;
this.score = score;
}
public String toString() {
return this.name + "," + this.score;
}
}

运行程序,我们会得到ClassCastException,即无法将Person转型为Comparable。我们修改代码,让Person实现Comparable<T>接口:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
// sort
import java.util.Arrays;

public class Main {
public static void main(String[] args) {
Person[] ps = new Person[] {
new Person("Bob", 61),
new Person("Alice", 88),
new Person("Lily", 75),
};
Arrays.sort(ps);
System.out.println(Arrays.toString(ps));
}
}

class Person implements Comparable<Person> {
String name;
int score;
Person(String name, int score) {
this.name = name;
this.score = score;
}
public int compareTo(Person other) {
return this.name.compareTo(other.name);
}
public String toString() {
return this.name + "," + this.score;
}
}

运行上述代码,可以正确实现按name进行排序。

也可以修改比较逻辑,例如,按score从高到低排序。请自行修改测试。

小结

使用泛型时,把泛型参数<T>替换为需要的class类型,例如:ArrayList<String>ArrayList<Number>等;

可以省略编译器能自动推断出的类型,例如:List<String> list = new ArrayList<>();

不指定泛型参数类型时,编译器会给出警告,且只能将<T>视为Object类型;

可以在接口中定义泛型类型,实现此接口的类必须实现正确的泛型类型。

编写泛型类比普通类要复杂。通常来说,泛型类一般用在集合类中,例如ArrayList<T>,我们很少需要编写泛型类。

如果我们确实需要编写一个泛型类,那么,应该如何编写它?

可以按照以下步骤来编写一个泛型类。

首先,按照某种类型,例如:String,来编写类:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
public class Pair {
private String first;
private String last;
public Pair(String first, String last) {
this.first = first;
this.last = last;
}
public String getFirst() {
return first;
}
public String getLast() {
return last;
}
}

然后,标记所有的特定类型,这里是String,把特定类型String替换为T,并申明<T>

1
2
3
4
5
6
7
8
9
10
11
12
13
14
public class Pair<T> {
private T first;
private T last;
public Pair(T first, T last) {
this.first = first;
this.last = last;
}
public T getFirst() {
return first;
}
public T getLast() {
return last;
}
}

熟练后即可直接从T开始编写。

静态方法

编写泛型类时,要特别注意,泛型类型<T>不能用于静态方法。例如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
public class Pair<T> {
private T first;
private T last;
public Pair(T first, T last) {
this.first = first;
this.last = last;
}
public T getFirst() { ... }
public T getLast() { ... }

// 对静态方法使用<T>:
public static Pair<T> create(T first, T last) {
return new Pair<T>(first, last);
}
}

上述代码会导致编译错误,我们无法在静态方法create()的方法参数和返回类型上使用泛型类型T

有些同学在网上搜索发现,可以在static修饰符后面加一个<T>,编译就能通过:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
public class Pair<T> {
private T first;
private T last;
public Pair(T first, T last) {
this.first = first;
this.last = last;
}
public T getFirst() { ... }
public T getLast() { ... }

// 可以编译通过:
public static <T> Pair<T> create(T first, T last) {
return new Pair<T>(first, last);
}
}

但实际上,这个<T>Pair<T>类型的<T>已经没有任何关系了。

对于静态方法,我们可以单独改写为“泛型”方法,只需要使用另一个类型即可。对于上面的create()静态方法,我们应该把它改为另一种泛型类型,例如,<K>

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
public class Pair<T> {
private T first;
private T last;
public Pair(T first, T last) {
this.first = first;
this.last = last;
}
public T getFirst() { ... }
public T getLast() { ... }

// 静态泛型方法应该使用其他类型区分:
public static <K> Pair<K> create(K first, K last) {
return new Pair<K>(first, last);
}
}

这样才能清楚地将静态方法的泛型类型和实例类型的泛型类型区分开。

多个泛型类型

泛型还可以定义多种类型。例如,我们希望Pair不总是存储两个类型一样的对象,就可以使用类型<T, K>

1
2
3
4
5
6
7
8
9
10
public class Pair<T, K> {
private T first;
private K last;
public Pair(T first, K last) {
this.first = first;
this.last = last;
}
public T getFirst() { ... }
public K getLast() { ... }
}

使用的时候,需要指出两种类型:

1
Pair<String, Integer> p = new Pair<>("test", 123);

Java标准库的Map<K, V>就是使用两种泛型类型的例子。它对Key使用一种类型,对Value使用另一种类型。

小结

编写泛型时,需要定义泛型类型<T>

静态方法不能引用泛型类型<T>,必须定义其他类型(例如<K>)来实现静态泛型方法;

泛型可以同时定义多种类型,例如Map<K, V>

泛型是一种类似”模板代码“的技术,不同语言的泛型实现方式不一定相同。

Java语言的泛型实现方式是擦拭法(Type Erasure)。

所谓擦拭法是指,虚拟机对泛型其实一无所知,所有的工作都是编译器做的。

例如,我们编写了一个泛型类Pair<T>,这是编译器看到的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
public class Pair<T> {
private T first;
private T last;
public Pair(T first, T last) {
this.first = first;
this.last = last;
}
public T getFirst() {
return first;
}
public T getLast() {
return last;
}
}

而虚拟机根本不知道泛型。这是虚拟机执行的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
public class Pair {
private Object first;
private Object last;
public Pair(Object first, Object last) {
this.first = first;
this.last = last;
}
public Object getFirst() {
return first;
}
public Object getLast() {
return last;
}
}

因此,Java使用擦拭法实现泛型,导致了:

  • 编译器把类型<T>视为Object
  • 编译器根据<T>实现安全的强制转型。

使用泛型的时候,我们编写的代码也是编译器看到的代码:

1
2
3
Pair<String> p = new Pair<>("Hello", "world");
String first = p.getFirst();
String last = p.getLast();

而虚拟机执行的代码并没有泛型:

1
2
3
Pair p = new Pair("Hello", "world");
String first = (String) p.getFirst();
String last = (String) p.getLast();

所以,Java的泛型是由编译器在编译时实行的,编译器内部永远把所有类型T视为Object处理,但是,在需要转型的时候,编译器会根据T的类型自动为我们实行安全地强制转型。

了解了Java泛型的实现方式——擦拭法,我们就知道了Java泛型的局限:

局限一:<T>不能是基本类型,例如int,因为实际类型是ObjectObject类型无法持有基本类型:

1
Pair<int> p = new Pair<>(1, 2); // compile error!

局限二:无法取得带泛型的Class。观察以下代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
public class Main {
public static void main(String[] args) {
Pair<String> p1 = new Pair<>("Hello", "world");
Pair<Integer> p2 = new Pair<>(123, 456);
Class c1 = p1.getClass();
Class c2 = p2.getClass();
System.out.println(c1==c2); // true
System.out.println(c1==Pair.class); // true
}
}

class Pair<T> {
private T first;
private T last;
public Pair(T first, T last) {
this.first = first;
this.last = last;
}
public T getFirst() {
return first;
}
public T getLast() {
return last;
}
}

因为TObject,我们对Pair<String>Pair<Integer>类型获取Class时,获取到的是同一个Class,也就是Pair类的Class

换句话说,所有泛型实例,无论T的类型是什么,getClass()返回同一个Class实例,因为编译后它们全部都是Pair<Object>

局限三:无法判断带泛型的类型:

1
2
3
4
Pair<Integer> p = new Pair<>(123, 456);
// Compile error:
if (p instanceof Pair<String>) {
}

原因和前面一样,并不存在Pair<String>.class,而是只有唯一的Pair.class

局限四:不能实例化T类型:

1
2
3
4
5
6
7
8
9
public class Pair<T> {
private T first;
private T last;
public Pair() {
// Compile error:
first = new T();
last = new T();
}
}

上述代码无法通过编译,因为构造方法的两行语句:

1
2
first = new T();
last = new T();

擦拭后实际上变成了:

1
2
first = new Object();
last = new Object();

这样一来,创建new Pair<String>()和创建new Pair<Integer>()就全部成了Object,显然编译器要阻止这种类型不对的代码。

要实例化T类型,我们必须借助额外的Class<T>参数:

1
2
3
4
5
6
7
8
public class Pair<T> {
private T first;
private T last;
public Pair(Class<T> clazz) {
first = clazz.newInstance();
last = clazz.newInstance();
}
}

上述代码借助Class<T>参数并通过反射来实例化T类型,使用的时候,也必须传入Class<T>。例如:

1
Pair<String> pair = new Pair<>(String.class);

因为传入了Class<String>的实例,所以我们借助String.class就可以实例化String类型。

不恰当的覆写方法

有些时候,一个看似正确定义的方法会无法通过编译。例如:

1
2
3
4
5
public class Pair<T> {
public boolean equals(T t) {
return this == t;
}
}

这是因为,定义的equals(T t)方法实际上会被擦拭成equals(Object t),而这个方法是继承自Object的,编译器会阻止一个实际上会变成覆写的泛型方法定义。

换个方法名,避开与Object.equals(Object)的冲突就可以成功编译:

1
2
3
4
5
public class Pair<T> {
public boolean same(T t) {
return this == t;
}
}

泛型继承

一个类可以继承自一个泛型类。例如:父类的类型是Pair<Integer>,子类的类型是IntPair,可以这么继承:

1
2
public class IntPair extends Pair<Integer> {
}

使用的时候,因为子类IntPair并没有泛型类型,所以,正常使用即可:

1
IntPair ip = new IntPair(1, 2);

前面讲了,我们无法获取Pair<T>T类型,即给定一个变量Pair<Integer> p,无法从p中获取到Integer类型。

但是,在父类是泛型类型的情况下,编译器就必须把类型T(对IntPair来说,也就是Integer类型)保存到子类的class文件中,不然编译器就不知道IntPair只能存取Integer这种类型。

在继承了泛型类型的情况下,子类可以获取父类的泛型类型。例如:IntPair可以获取到父类的泛型类型Integer。获取父类的泛型类型代码比较复杂:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;

public class Main {
public static void main(String[] args) {
Class<IntPair> clazz = IntPair.class;
Type t = clazz.getGenericSuperclass();
if (t instanceof ParameterizedType) {
ParameterizedType pt = (ParameterizedType) t;
Type[] types = pt.getActualTypeArguments(); // 可能有多个泛型类型
Type firstType = types[0]; // 取第一个泛型类型
Class<?> typeClass = (Class<?>) firstType;
System.out.println(typeClass); // Integer
}
}
}

class Pair<T> {
private T first;
private T last;
public Pair(T first, T last) {
this.first = first;
this.last = last;
}
public T getFirst() {
return first;
}
public T getLast() {
return last;
}
}

class IntPair extends Pair<Integer> {
public IntPair(Integer first, Integer last) {
super(first, last);
}
}

因为Java引入了泛型,所以,只用Class来标识类型已经不够了。实际上,Java的类型系统结构如下:

1
2
3
4
5
6
7
8
9
10
                      ┌────┐
│Type│
└────┘


┌────────────┬────────┴─────────┬───────────────┐
│ │ │ │
┌─────┐┌─────────────────┐┌────────────────┐┌────────────┐
│Class││ParameterizedType││GenericArrayType││WildcardType│
└─────┘└─────────────────┘└────────────────┘└────────────┘

小结

Java的泛型是采用擦拭法实现的;

擦拭法决定了泛型<T>

  • 不能是基本类型,例如:int
  • 不能获取带泛型类型的Class,例如:Pair<String>.class
  • 不能判断带泛型类型的类型,例如:x instanceof Pair<String>
  • 不能实例化T类型,例如:new T()

泛型方法要防止重复定义方法,例如:public boolean equals(T obj)

子类可以获取父类的泛型类型<T>

我们前面已经讲到了泛型的继承关系:Pair<Integer>不是Pair<Number>的子类。

假设我们定义了Pair<T>

1
public class Pair<T> { ... }

然后,我们又针对Pair<Number>类型写了一个静态方法,它接收的参数类型是Pair<Number>

1
2
3
4
5
6
7
public class PairHelper {
static int add(Pair<Number> p) {
Number first = p.getFirst();
Number last = p.getLast();
return first.intValue() + last.intValue();
}
}

上述代码是可以正常编译的。使用的时候,我们传入:

1
int sum = PairHelper.add(new Pair<Number>(1, 2));

注意:传入的类型是Pair<Number>,实际参数类型是(Integer, Integer)

既然实际参数是Integer类型,试试传入Pair<Integer>

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
public class Main {
public static void main(String[] args) {
Pair<Integer> p = new Pair<>(123, 456);
int n = add(p);
System.out.println(n);
}

static int add(Pair<Number> p) {
Number first = p.getFirst();
Number last = p.getLast();
return first.intValue() + last.intValue();
}
}

class Pair<T> {
private T first;
private T last;
public Pair(T first, T last) {
this.first = first;
this.last = last;
}
public T getFirst() {
return first;
}
public T getLast() {
return last;
}
}

直接运行,会得到一个编译错误:

1
incompatible types: Pair<Integer> cannot be converted to Pair<Number>

原因很明显,因为Pair<Integer>不是Pair<Number>的子类,因此,add(Pair<Number>)不接受参数类型Pair<Integer>

但是从add()方法的代码可知,传入Pair<Integer>是完全符合内部代码的类型规范,因为语句:

1
2
Number first = p.getFirst();
Number last = p.getLast();

实际类型是Integer,引用类型是Number,没有问题。问题在于方法参数类型定死了只能传入Pair<Number>

有没有办法使得方法参数接受Pair<Integer>?办法是有的,这就是使用Pair<? extends Number>使得方法接收所有泛型类型为NumberNumber子类的Pair类型。我们把代码改写如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
public class Main {
public static void main(String[] args) {
Pair<Integer> p = new Pair<>(123, 456);
int n = add(p);
System.out.println(n);
}

static int add(Pair<? extends Number> p) {
Number first = p.getFirst();
Number last = p.getLast();
return first.intValue() + last.intValue();
}
}

class Pair<T> {
private T first;
private T last;
public Pair(T first, T last) {
this.first = first;
this.last = last;
}
public T getFirst() {
return first;
}
public T getLast() {
return last;
}
}

这样一来,给方法传入Pair<Integer>类型时,它符合参数Pair<? extends Number>类型。这种使用<? extends Number>的泛型定义称之为上界通配符(Upper Bounds Wildcards),即把泛型类型T的上界限定在Number了。

除了可以传入Pair<Integer>类型,我们还可以传入Pair<Double>类型,Pair<BigDecimal>类型等等,因为DoubleBigDecimal都是Number的子类。

如果我们考察对Pair<? extends Number>类型调用getFirst()方法,实际的方法签名变成了:

1
<? extends Number> getFirst();

即返回值是NumberNumber的子类,因此,可以安全赋值给Number类型的变量:

1
Number x = p.getFirst();

然后,我们不可预测实际类型就是Integer,例如,下面的代码是无法通过编译的:

1
Integer x = p.getFirst();

这是因为实际的返回类型可能是Integer,也可能是Double或者其他类型,编译器只能确定类型一定是Number的子类(包括Number类型本身),但具体类型无法确定。

我们再来考察一下Pair<T>set方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
public class Main {
public static void main(String[] args) {
Pair<Integer> p = new Pair<>(123, 456);
int n = add(p);
System.out.println(n);
}

static int add(Pair<? extends Number> p) {
Number first = p.getFirst();
Number last = p.getLast();
p.setFirst(new Integer(first.intValue() + 100));
p.setLast(new Integer(last.intValue() + 100));
return p.getFirst().intValue() + p.getLast().intValue();
}
}

class Pair<T> {
private T first;
private T last;

public Pair(T first, T last) {
this.first = first;
this.last = last;
}

public T getFirst() {
return first;
}
public T getLast() {
return last;
}
public void setFirst(T first) {
this.first = first;
}
public void setLast(T last) {
this.last = last;
}
}

不出意外,我们会得到一个编译错误:

1
2
3
incompatible types: Integer cannot be converted to CAP#1
where CAP#1 is a fresh type-variable:
CAP#1 extends Number from capture of ? extends Number

编译错误发生在p.setFirst()传入的参数是Integer类型。有些童鞋会问了,既然p的定义是Pair<? extends Number>,那么setFirst(? extends Number)为什么不能传入Integer

原因还在于擦拭法。如果我们传入的pPair<Double>,显然它满足参数定义Pair<? extends Number>,然而,Pair<Double>setFirst()显然无法接受Integer类型。

这就是<? extends Number>通配符的一个重要限制:方法参数签名setFirst(? extends Number)无法传递任何Number的子类型给setFirst(? extends Number)

这里唯一的例外是可以给方法参数传入null

1
2
p.setFirst(null); // ok, 但是后面会抛出NullPointerException
p.getFirst().intValue(); // NullPointerException

extends通配符的作用

如果我们考察Java标准库的java.util.List<T>接口,它实现的是一个类似“可变数组”的列表,主要功能包括:

1
2
3
4
5
6
public interface List<T> {
int size(); // 获取个数
T get(int index); // 根据索引获取指定元素
void add(T t); // 添加一个新元素
void remove(T t); // 删除一个已有元素
}

现在,让我们定义一个方法来处理列表的每个元素:

1
2
3
4
5
6
7
8
int sumOfList(List<? extends Integer> list) {
int sum = 0;
for (int i=0; i<list.size(); i++) {
Integer n = list.get(i);
sum = sum + n;
}
return sum;
}

为什么我们定义的方法参数类型是List<? extends Integer>而不是List<Integer>?从方法内部代码看,传入List<? extends Integer>或者List<Integer>是完全一样的,但是,注意到List<? extends Integer>的限制:

  • 允许调用get()方法获取Integer的引用;
  • 不允许调用set(? extends Integer)方法并传入任何Integer的引用(null除外)。

因此,方法参数类型List<? extends Integer>表明了该方法内部只会读取List的元素,不会修改List的元素(因为无法调用add(? extends Integer)remove(? extends Integer)这些方法。换句话说,这是一个对参数List<? extends Integer>进行只读的方法(恶意调用set(null)除外)。

使用extends限定T类型

在定义泛型类型Pair<T>的时候,也可以使用extends通配符来限定T的类型:

1
public class Pair<T extends Number> { ... }

现在,我们只能定义:

1
2
3
Pair<Number> p1 = null;
Pair<Integer> p2 = new Pair<>(1, 2);
Pair<Double> p3 = null;

因为NumberIntegerDouble都符合<T extends Number>

Number类型将无法通过编译:

1
2
Pair<String> p1 = null; // compile error!
Pair<Object> p2 = null; // compile error!

因为StringObject都不符合<T extends Number>,因为它们不是Number类型或Number的子类。

小结

使用类似<? extends Number>通配符作为方法参数时表示:

  • 方法内部可以调用获取Number引用的方法,例如:Number n = obj.getFirst();
  • 方法内部无法调用传入Number引用的方法(null除外),例如:obj.setFirst(Number n);

即一句话总结:使用extends通配符表示可以读,不能写。

使用类似<T extends Number>定义泛型类时表示:

  • 泛型类型限定为Number以及Number的子类。

我们前面已经讲到了泛型的继承关系:Pair<Integer>不是Pair<Number>的子类。

考察下面的set方法:

1
2
3
4
void set(Pair<Integer> p, Integer first, Integer last) {
p.setFirst(first);
p.setLast(last);
}

传入Pair<Integer>是允许的,但是传入Pair<Number>是不允许的。

extends通配符相反,这次,我们希望接受Pair<Integer>类型,以及Pair<Number>Pair<Object>,因为NumberObjectInteger的父类,setFirst(Number)setFirst(Object)实际上允许接受Integer类型。

我们使用super通配符来改写这个方法:

1
2
3
4
void set(Pair<? super Integer> p, Integer first, Integer last) {
p.setFirst(first);
p.setLast(last);
}

注意到Pair<? super Integer>表示,方法参数接受所有泛型类型为IntegerInteger父类的Pair类型。

下面的代码可以被正常编译:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
public class Main {
public static void main(String[] args) {
Pair<Number> p1 = new Pair<>(12.3, 4.56);
Pair<Integer> p2 = new Pair<>(123, 456);
setSame(p1, 100);
setSame(p2, 200);
System.out.println(p1.getFirst() + ", " + p1.getLast());
System.out.println(p2.getFirst() + ", " + p2.getLast());
}

static void setSame(Pair<? super Integer> p, Integer n) {
p.setFirst(n);
p.setLast(n);
}
}

class Pair<T> {
private T first;
private T last;

public Pair(T first, T last) {
this.first = first;
this.last = last;
}

public T getFirst() {
return first;
}
public T getLast() {
return last;
}
public void setFirst(T first) {
this.first = first;
}
public void setLast(T last) {
this.last = last;
}
}

考察Pair<? super Integer>setFirst()方法,它的方法签名实际上是:

1
void setFirst(? super Integer);

因此,可以安全地传入Integer类型。

再考察Pair<? super Integer>getFirst()方法,它的方法签名实际上是:

1
? super Integer getFirst();

这里注意到我们无法使用Integer类型来接收getFirst()的返回值,即下面的语句将无法通过编译:

1
Integer x = p.getFirst();

因为如果传入的实际类型是Pair<Number>,编译器无法将Number类型转型为Integer

注意:虽然Number是一个抽象类,我们无法直接实例化它。但是,即便Number不是抽象类,这里仍然无法通过编译。此外,传入Pair<Object>类型时,编译器也无法将Object类型转型为Integer

唯一可以接收getFirst()方法返回值的是Object类型:

1
Object obj = p.getFirst();

因此,使用<? super Integer>通配符表示:

  • 允许调用set(? super Integer)方法传入Integer的引用;
  • 不允许调用get()方法获得Integer的引用。

唯一例外是可以获取Object的引用:Object o = p.getFirst()

换句话说,使用<? super Integer>通配符作为方法参数,表示方法内部代码对于参数只能写,不能读。

对比extends和super通配符

我们再回顾一下extends通配符。作为方法参数,<? extends T>类型和<? super T>类型的区别在于:

  • <? extends T>允许调用读方法T get()获取T的引用,但不允许调用写方法set(T)传入T的引用(传入null除外);
  • <? super T>允许调用写方法set(T)传入T的引用,但不允许调用读方法T get()获取T的引用(获取Object除外)。

一个是允许读不允许写,另一个是允许写不允许读。

先记住上面的结论,我们来看Java标准库的Collections类定义的copy()方法:

1
2
3
4
5
6
7
8
9
public class Collections {
// 把src的每个元素复制到dest中:
public static <T> void copy(List<? super T> dest, List<? extends T> src) {
for (int i=0; i<src.size(); i++) {
T t = src.get(i);
dest.add(t);
}
}
}

它的作用是把一个List的每个元素依次添加到另一个List中。它的第一个参数是List<? super T>,表示目标List,第二个参数List<? extends T>,表示要复制的List。我们可以简单地用for循环实现复制。在for循环中,我们可以看到,对于类型<? extends T>的变量src,我们可以安全地获取类型T的引用,而对于类型<? super T>的变量dest,我们可以安全地传入T的引用。

这个copy()方法的定义就完美地展示了extendssuper的意图:

  • copy()方法内部不会读取dest,因为不能调用dest.get()来获取T的引用;
  • copy()方法内部也不会修改src,因为不能调用src.add(T)

这是由编译器检查来实现的。如果在方法代码中意外修改了src,或者意外读取了dest,就会导致一个编译错误:

1
2
3
4
5
6
7
8
public class Collections {
// 把src的每个元素复制到dest中:
public static <T> void copy(List<? super T> dest, List<? extends T> src) {
...
T t = dest.get(0); // compile error!
src.add(t); // compile error!
}
}

这个copy()方法的另一个好处是可以安全地把一个List<Integer>添加到List<Number>,但是无法反过来添加:

1
2
3
4
5
6
7
// copy List<Integer> to List<Number> ok:
List<Number> numList = ...;
List<Integer> intList = ...;
Collections.copy(numList, intList);

// ERROR: cannot copy List<Number> to List<Integer>:
Collections.copy(intList, numList);

而这些都是通过superextends通配符,并由编译器强制检查来实现的。

PECS原则

何时使用extends,何时使用super?为了便于记忆,我们可以用PECS原则:Producer Extends Consumer Super。

即:如果需要返回T,它是生产者(Producer),要使用extends通配符;如果需要写入T,它是消费者(Consumer),要使用super通配符。

还是以Collectionscopy()方法为例:

1
2
3
4
5
6
7
8
public class Collections {
public static <T> void copy(List<? super T> dest, List<? extends T> src) {
for (int i=0; i<src.size(); i++) {
T t = src.get(i); // src是producer
dest.add(t); // dest是consumer
}
}
}

需要返回Tsrc是生产者,因此声明为List<? extends T>,需要写入Tdest是消费者,因此声明为List<? super T>

无限定通配符

我们已经讨论了<? extends T><? super T>作为方法参数的作用。实际上,Java的泛型还允许使用无限定通配符(Unbounded Wildcard Type),即只定义一个?

1
2
void sample(Pair<?> p) {
}

因为<?>通配符既没有extends,也没有super,因此:

  • 不允许调用set(T)方法并传入引用(null除外);
  • 不允许调用T get()方法并获取T引用(只能获取Object引用)。

换句话说,既不能读,也不能写,那只能做一些null判断:

1
2
3
static boolean isNull(Pair<?> p) {
return p.getFirst() == null || p.getLast() == null;
}

大多数情况下,可以引入泛型参数<T>消除<?>通配符:

1
2
3
static <T> boolean isNull(Pair<T> p) {
return p.getFirst() == null || p.getLast() == null;
}

<?>通配符有一个独特的特点,就是:Pair<?>是所有Pair<T>的超类:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
public class Main {
public static void main(String[] args) {
Pair<Integer> p = new Pair<>(123, 456);
Pair<?> p2 = p; // 安全地向上转型
System.out.println(p2.getFirst() + ", " + p2.getLast());
}
}

class Pair<T> {
private T first;
private T last;

public Pair(T first, T last) {
this.first = first;
this.last = last;
}

public T getFirst() {
return first;
}
public T getLast() {
return last;
}
public void setFirst(T first) {
this.first = first;
}
public void setLast(T last) {
this.last = last;
}
}

上述代码是可以正常编译运行的,因为Pair<Integer>Pair<?>的子类,可以安全地向上转型。

小结

使用类似<? super Integer>通配符作为方法参数时表示:

  • 方法内部可以调用传入Integer引用的方法,例如:obj.setFirst(Integer n);
  • 方法内部无法调用获取Integer引用的方法(Object除外),例如:Integer n = obj.getFirst();

即使用super通配符表示只能写不能读。

使用extendssuper通配符要遵循PECS原则。

无限定通配符<?>很少使用,可以用<T>替换,同时它是所有<T>类型的超类。

Java的部分反射API也是泛型。例如:Class<T>就是泛型:

1
2
3
4
5
6
7
// compile warning:
Class clazz = String.class;
String str = (String) clazz.newInstance();

// no warning:
Class<String> clazz = String.class;
String str = clazz.newInstance();

调用ClassgetSuperclass()方法返回的Class类型是Class<? super T>

1
Class<? super String> sup = String.class.getSuperclass();

构造方法Constructor<T>也是泛型:

1
2
3
Class<Integer> clazz = Integer.class;
Constructor<Integer> cons = clazz.getConstructor(int.class);
Integer i = cons.newInstance(123);

我们可以声明带泛型的数组,但不能用new操作符创建带泛型的数组:

1
2
Pair<String>[] ps = null; // ok
Pair<String>[] ps = new Pair<String>[2]; // compile error!

必须通过强制转型实现带泛型的数组:

1
2
@SuppressWarnings("unchecked")
Pair<String>[] ps = (Pair<String>[]) new Pair[2];

使用泛型数组要特别小心,因为数组实际上在运行期没有泛型,编译器可以强制检查变量ps,因为它的类型是泛型数组。但是,编译器不会检查变量arr,因为它不是泛型数组。因为这两个变量实际上指向同一个数组,所以,操作arr可能导致从ps获取元素时报错,例如,以下代码演示了不安全地使用带泛型的数组:

1
2
3
4
5
6
7
8
9
Pair[] arr = new Pair[2];
Pair<String>[] ps = (Pair<String>[]) arr;

ps[0] = new Pair<String>("a", "b");
arr[1] = new Pair<Integer>(1, 2);

// ClassCastException:
Pair<String> p = ps[1];
String s = p.getFirst();

要安全地使用泛型数组,必须扔掉arr的引用:

1
2
@SuppressWarnings("unchecked")
Pair<String>[] ps = (Pair<String>[]) new Pair[2];

上面的代码中,由于拿不到原始数组的引用,就只能对泛型数组ps进行操作,这种操作就是安全的。

带泛型的数组实际上是编译器的类型擦除:

1
2
3
4
5
6
7
Pair[] arr = new Pair[2];
Pair<String>[] ps = (Pair<String>[]) arr;

System.out.println(ps.getClass() == Pair[].class); // true

String s1 = (String) arr[0].getFirst();
String s2 = ps[0].getFirst();

所以我们不能直接创建泛型数组T[],因为擦拭后代码变为Object[]

1
2
3
4
5
6
// compile error:
public class Abc<T> {
T[] createArray() {
return new T[5];
}
}

必须借助Class<T>来创建泛型数组:

1
2
3
T[] createArray(Class<T> cls) {
return (T[]) Array.newInstance(cls, 5);
}

我们还可以利用可变参数创建泛型数组T[]

1
2
3
4
5
6
7
8
9
public class ArrayHelper {
@SafeVarargs
static <T> T[] asArray(T... objs) {
return objs;
}
}

String[] ss = ArrayHelper.asArray("a", "b", "c");
Integer[] ns = ArrayHelper.asArray(1, 2, 3);

谨慎使用泛型可变参数

在上面的例子中,我们看到,通过:

1
2
3
static <T> T[] asArray(T... objs) {
return objs;
}

似乎可以安全地创建一个泛型数组。但实际上,这种方法非常危险。以下代码来自《Effective Java》的示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import java.util.Arrays;

public class Main {
public static void main(String[] args) {
String[] arr = asArray("one", "two", "three");
System.out.println(Arrays.toString(arr));
// ClassCastException:
String[] firstTwo = pickTwo("one", "two", "three");
System.out.println(Arrays.toString(firstTwo));
}

static <K> K[] pickTwo(K k1, K k2, K k3) {
return asArray(k1, k2);
}

static <T> T[] asArray(T... objs) {
return objs;
}
}

直接调用asArray(T...)似乎没有问题,但是在另一个方法中,我们返回一个泛型数组就会产生ClassCastException,原因还是因为擦拭法,在pickTwo()方法内部,编译器无法检测K[]的正确类型,因此返回了Object[]

如果仔细观察,可以发现编译器对所有可变泛型参数都会发出警告,除非确认完全没有问题,才可以用@SafeVarargs消除警告。

注意

如果在方法内部创建了泛型数组,最好不要将它返回给外部使用。

更详细的解释请参考《Effective Java》“Item 32: Combine generics and varargs judiciously”。

小结

部分反射API是泛型,例如:Class<T>Constructor<T>

可以声明带泛型的数组,但不能直接创建带泛型的数组,必须强制转型;

可以通过Array.newInstance(Class<T>, int)创建T[]数组,需要强制转型;

同时使用泛型和可变参数时需要特别小心。

留言與分享

作者的圖片

Kein Chan

這是獨立全棧工程師Kein Chan的技術博客
分享一些技術教程,命令備忘(cheat-sheet)等


全棧工程師
資深技術顧問
數據科學家
Hit廣島觀光大使


Tokyo/Macau