spring bean初始化异步执行
目录
- 问题引入:很多bean初始化很慢
- 统计bean初始化方法耗时:自定义BeanPostProcessor
- 自定义beanFactory
- 附 bean的生命周期
- 附 `DefaultListableBeanFactory`类图
- 附 AbstractAutowireCapableBeanFactory的invokeInitMethods
- 自定义BeanFactoryPostProcessor给bean打标
- 要异步初始化的bean 例子如下
- AnnotationConfigApplicationContext 测试
- 测试
问题引入:很多bean初始化很慢
考虑如下简单的程序
package org.example;

import org.springframework.context.annotation.AnnotationConfigApplicationContext;

/**
 * Baseline demo: registers {@code Config}, refreshes the context, prints how long
 * {@code refresh()} took (in whole seconds), then calls bean {@code a}.
 */
public class Main {

    public static void main(String[] args) {
        // try-with-resources closes the context on exit; stop() alone only signals
        // Lifecycle beans and never destroys the singletons.
        try (AnnotationConfigApplicationContext applicationContext = new AnnotationConfigApplicationContext()) {
            applicationContext.register(Config.class);

            long startTime = System.currentTimeMillis();
            applicationContext.refresh();
            long cost = System.currentTimeMillis() - startTime;
            // Integer division: the elapsed time is truncated to whole seconds.
            System.out.println(String.format("applicationContext refresh cost:%d s", cost / 1000));

            A a = (A) applicationContext.getBean("a");
            a.sayHello();
        }
    }
}
- A,B两个bean 基本定义如下,其初始化方法可能很耗时
package org.example;

import org.springframework.stereotype.Component;
import org.springframework.stereotype.Service;

import javax.annotation.PostConstruct;
import java.util.concurrent.TimeUnit;

/**
 * Demo bean whose {@code @PostConstruct} method sleeps for two seconds to simulate
 * a slow initialization.
 */
@Service
public class A {

    /** Simulates slow initialization work, then reports success on stdout. */
    @PostConstruct
    public void init() {
        try {
            TimeUnit.SECONDS.sleep(2);
        } catch (InterruptedException e) {
            // Do not swallow the interrupt: restore the flag so callers can observe it.
            Thread.currentThread().interrupt();
        }
        System.out.println("A.init success");
    }

    /** Prints a greeting to stdout. */
    public void sayHello() {
        System.out.println("A.sayHello");
    }
}
如下 applicationContext refresh要5秒多
统计bean初始化方法耗时:自定义BeanPostProcessor
即postProcessBeforeInitialization记录bean的开始时间,
postProcessAfterInitialization记录bean初始化完成时间,然后就能得到bean初始化方法耗时。
package org.example;

import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.context.ApplicationEvent;
import org.springframework.context.ApplicationListener;
import org.springframework.context.event.ContextRefreshedEvent;
import org.springframework.stereotype.Component;

import java.time.Instant;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;

import static java.util.Collections.reverseOrder;

/**
 * Measures, per bean name, the wall-clock time between
 * {@link #postProcessBeforeInitialization} and {@link #postProcessAfterInitialization},
 * and prints the measurements (slowest first) when a {@link ContextRefreshedEvent} arrives.
 *
 * <p>Only the first completed measurement for a given bean name is kept.
 * The collections used here are not synchronized.
 */
@Component
public class BeanInitMethodCostTimeBeanPostProcessor implements BeanPostProcessor, ApplicationListener<ApplicationEvent> {

    /** Global counter used to number the recorded measurements in completion order. */
    private static final AtomicInteger ATOMIC_INTEGER = new AtomicInteger(0);

    /** Bean name -> timestamp (epoch millis) taken just before initialization. */
    private final Map<String, Long> startTime = new HashMap<>(1024);

    /** Completed measurements, in the order they were recorded. */
    private final List<Initialization> costTime = new ArrayList<>(1024);

    /** Names already present in {@link #costTime}; gives an O(1) duplicate check. */
    private final Set<String> recordedBeanNames = new HashSet<>(1024);

    /**
     * Records the start timestamp for {@code beanName}.
     *
     * @return the bean, unchanged
     */
    @Override
    public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
        startTime.put(beanName, System.currentTimeMillis());
        return bean;
    }

    /**
     * Records the elapsed time for {@code beanName} unless one was already recorded
     * or no start timestamp exists for it.
     *
     * @return the bean, unchanged
     */
    @Override
    public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
        if (recordedBeanNames.contains(beanName)) {
            return bean;
        }
        long end = System.currentTimeMillis();
        Long start = startTime.get(beanName);
        if (start != null) {
            costTime.add(Initialization.parseInitialization(beanName, start, end));
            recordedBeanNames.add(beanName);
        }
        return bean;
    }

    /**
     * On {@link ContextRefreshedEvent}: prints all measurements sorted by descending
     * cost and then discards the collected data. Other events are ignored.
     */
    @Override
    public void onApplicationEvent(ApplicationEvent event) {
        if (!(event instanceof ContextRefreshedEvent)) {
            return;
        }
        costTime.sort(reverseOrder());
        for (Initialization initialization : costTime) {
            System.out.println(initialization.toString());
        }
        startTime.clear();
        costTime.clear();
        recordedBeanNames.clear();
    }

    /** One measurement; the natural ordering is ascending by elapsed time. */
    private static class Initialization implements Comparable<Initialization> {

        /** Formatter is immutable and thread-safe, so it is built once instead of per call. */
        private static final DateTimeFormatter TIME_FORMATTER = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");

        private final int serialNumber;
        private final String beanName;
        private final long costTime;
        private final long start;
        private final long end;

        private Initialization(String beanName, long start, long end) {
            this.serialNumber = ATOMIC_INTEGER.incrementAndGet();
            this.beanName = beanName;
            this.start = start;
            this.end = end;
            this.costTime = end - start;
        }

        /**
         * Creates a measurement and assigns it the next serial number.
         *
         * @param beanName name of the measured bean
         * @param start    epoch millis before initialization
         * @param end      epoch millis after initialization
         */
        public static Initialization parseInitialization(String beanName, long start, long end) {
            return new Initialization(beanName, start, end);
        }

        @Override
        public String toString() {
            return "serialNumber: " + serialNumber + ",beanName: " + beanName + ",cost " + costTime + " ms,"
                    + " start: " + convertTimeToString(start) + ", end:" + convertTimeToString(end);
        }

        /** Formats epoch millis as {@code yyyy-MM-dd HH:mm:ss} in the system default zone. */
        public static String convertTimeToString(Long time) {
            return TIME_FORMATTER.format(LocalDateTime.ofInstant(Instant.ofEpochMilli(time), ZoneId.systemDefault()));
        }

        @Override
        public int compareTo(Initialization o) {
            return Long.compare(costTime, o.costTime);
        }
    }
}
可以看到如下:a,b 两个bean 初始化耗时很久, applicationContext refresh耗时也主要是由于a,b两个bean初始化导致
自定义beanFactory
继承:DefaultListableBeanFactory, 重写invokeInitMethods方法
package org.example;

import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.BeanCreationException;
import org.springframework.beans.factory.FactoryBean;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.RootBeanDefinition;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

import static java.lang.Boolean.TRUE;

/**
 * Bean factory that runs the init methods ({@code afterPropertiesSet()} and the configured
 * custom init method) of selected beans on a thread pool instead of the calling thread.
 *
 * <p>A bean is selected when its bean definition carries the attribute
 * {@code Constant.ASYNC_INIT} with value {@code true} (Boolean or String), it is not lazy-init,
 * and it is not a {@link FactoryBean}. For such beans {@link #invokeInitMethods} returns as soon
 * as the work has been submitted, i.e. before the init methods have completed.
 *
 * <p>The caller must invoke {@link #confirmAllAsyncTaskHadSuccessfulInvoked()} after the
 * context refresh: it waits for all submitted init methods, surfaces their failures and shuts
 * the pool down. Until then the pool's worker threads stay alive.
 */
public class CustomBeanFactory extends DefaultListableBeanFactory {

    /** Name of the {@link InitializingBean} callback invoked after dependency injection. */
    private static final String AFTER_PROPERTIES_SET_METHOD_NAME = "afterPropertiesSet";

    /** Fixed number of worker threads used for asynchronous initialization. */
    private static final int POOL_SIZE = 4;

    /**
     * Futures of all submitted init tasks. Per-instance (not static) so that creating a second
     * factory cannot mix its tasks with, or replace the pool of, the first one.
     */
    private final List<Future<Pair<String, Throwable>>> taskList =
            Collections.synchronizedList(new ArrayList<>());

    private final ExecutorService asyncInitPool;

    /**
     * Set once the async phase is over; afterwards every bean is initialized synchronously.
     * Volatile because beans may be created later from threads other than the one that
     * finished the startup (e.g. lazy beans) — NOTE(review): confirm against actual usage.
     */
    private volatile boolean contextFinished = false;

    public CustomBeanFactory() {
        super();
        asyncInitPool = new ThreadPoolExecutor(POOL_SIZE, POOL_SIZE, 0L, TimeUnit.MILLISECONDS,
                new LinkedBlockingQueue<>());
    }

    /**
     * Waits until every submitted async init task has finished and rethrows the first failure.
     *
     * <p>Whether or not a task failed, the pool is shut down and the factory switches to
     * synchronous initialization, so a failed startup does not leave worker threads behind.
     *
     * @return {@code true} once all tasks completed successfully
     * @throws BeanCreationException if an init method failed or the wait was interrupted
     */
    public boolean confirmAllAsyncTaskHadSuccessfulInvoked() {
        try {
            // A synchronized list must be locked manually while it is traversed; take a snapshot.
            List<Future<Pair<String, Throwable>>> pending;
            synchronized (taskList) {
                pending = new ArrayList<>(taskList);
            }
            for (Future<Pair<String, Throwable>> task : pending) {
                Pair<String, Throwable> result = task.get();
                if (result.getRight() != null) {
                    throw result.getRight();
                }
            }
        } catch (BeanCreationException e) {
            throw e;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new BeanCreationException("Interrupted while waiting for async init methods", e);
        } catch (Throwable e) {
            throw new BeanCreationException(e.getMessage(), e);
        } finally {
            contextFinished = true;
            asyncInitPool.shutdown();
        }
        return contextFinished;
    }

    /**
     * Runs the init methods asynchronously for beans selected by {@link #canAsyncInit};
     * delegates to the default implementation for all other beans.
     */
    @Override
    protected void invokeInitMethods(final String beanName, final Object bean, final RootBeanDefinition mbd)
            throws Throwable {
        if (!canAsyncInit(bean, mbd)) {
            super.invokeInitMethods(beanName, bean, mbd);
            return;
        }

        // canAsyncInit() returned true, therefore mbd is non-null from here on.
        final boolean isInitializingBean = bean instanceof InitializingBean;
        final boolean needInvokeAfterPropertiesSetMethod = isInitializingBean
                && !mbd.isExternallyManagedInitMethod(AFTER_PROPERTIES_SET_METHOD_NAME);

        final String initMethodName = mbd.getInitMethodName();
        // Skip the custom init method when it is absent/empty, when it is afterPropertiesSet on an
        // InitializingBean (already covered above), or when it is externally managed.
        final boolean needInvokeInitMethod = initMethodName != null
                && !initMethodName.isEmpty()
                && !(isInitializingBean && AFTER_PROPERTIES_SET_METHOD_NAME.equals(initMethodName))
                && !mbd.isExternallyManagedInitMethod(initMethodName);

        if (!needInvokeAfterPropertiesSetMethod && !needInvokeInitMethod) {
            return;
        }

        final boolean enforceInitMethod = mbd.isEnforceInitMethod();
        asyncInvoke(beanName, () -> {
            if (needInvokeAfterPropertiesSetMethod) {
                ((InitializingBean) bean).afterPropertiesSet();
            }
            if (needInvokeInitMethod) {
                invokeInitMethod(beanName, bean, initMethodName, enforceInitMethod);
            }
        });
    }

    /**
     * Reflectively invokes the no-arg method {@code method} on {@code bean}.
     *
     * @throws NoSuchMethodException if the method is missing and {@code enforceInitMethod} is set
     * @throws Throwable             whatever the init method itself throws (unwrapped)
     */
    private void invokeInitMethod(String beanName, Object bean, String method, boolean enforceInitMethod)
            throws Throwable {
        Method initMethod = BeanUtils.findMethod(bean.getClass(), method);
        if (initMethod == null) {
            if (enforceInitMethod) {
                throw new NoSuchMethodException("Couldn't find an init method named '" + method +
                        "' on bean with name '" + beanName + "'");
            }
            return;
        }
        initMethod.setAccessible(true);
        try {
            initMethod.invoke(bean);
        } catch (InvocationTargetException e) {
            // Report the init method's own exception rather than the reflection wrapper.
            throw e.getTargetException();
        }
    }

    /**
     * Submits {@code action} to the pool and remembers its future. The task never throws:
     * a failure is returned as the right element of the pair, wrapped in a
     * {@link BeanCreationException}.
     */
    private void asyncInvoke(final String beanName, final InitAction action) {
        taskList.add(asyncInitPool.submit(() -> {
            long start = System.currentTimeMillis();
            try {
                action.run();
                return Pair.<String, Throwable>of(beanName, null);
            } catch (Throwable throwable) {
                return Pair.<String, Throwable>of(beanName, new BeanCreationException(
                        beanName + ": Async Invocation of init method failed", throwable));
            } finally {
                System.out.println("asyncInvokeInitMethod " + beanName + " cost:"
                        + (System.currentTimeMillis() - start) + "ms.");
            }
        }));
    }

    /**
     * Decides whether a bean's init methods may run asynchronously: only while the async phase
     * is open, for non-lazy, non-FactoryBean beans whose definition is marked with
     * {@code Constant.ASYNC_INIT}.
     */
    private boolean canAsyncInit(Object bean, RootBeanDefinition mbd) {
        if (contextFinished || mbd == null || mbd.isLazyInit() || bean instanceof FactoryBean) {
            return false;
        }
        Object value = mbd.getAttribute(Constant.ASYNC_INIT);
        return TRUE.equals(value) || "true".equals(value);
    }

    /** Initialization work that may throw any {@link Throwable}. */
    @FunctionalInterface
    private interface InitAction {
        void run() throws Throwable;
    }
}
使用线程池,将初始化方法加入任务队列,并通过反射的方式执行;
另外加上一个确保所有任务都正确执行的方法
附 bean的生命周期
参考:https://doctording.blog.csdn.net/article/details/145044487
附 DefaultListableBeanFactory
类图
附 AbstractAutowireCapableBeanFactory的invokeInitMethods
/**
 * Runs a bean's initialization callbacks on the calling thread: first
 * {@code afterPropertiesSet()} when the bean implements {@code InitializingBean}, then the
 * custom init method named in the bean definition (if any).
 *
 * @param beanName name of the bean, used in the debug log message and passed to the custom init call
 * @param bean     the bean instance to initialize
 * @param mbd      the bean definition; may be {@code null}, in which case no custom init method runs
 * @throws Throwable whatever the callbacks throw
 */
protected void invokeInitMethods(String beanName, Object bean, @Nullable RootBeanDefinition mbd) throws Throwable {
    boolean isInitializingBean = bean instanceof InitializingBean;
    // Step 1: afterPropertiesSet(), skipped when the definition reports it as externally managed.
    if (isInitializingBean && (mbd == null || !mbd.isExternallyManagedInitMethod("afterPropertiesSet"))) {
        if (this.logger.isDebugEnabled()) {
            this.logger.debug("Invoking afterPropertiesSet() on bean with name '" + beanName + "'");
        }
        if (System.getSecurityManager() != null) {
            // With a security manager installed, the callback runs as a privileged action.
            try {
                AccessController.doPrivileged(() -> {
                    ((InitializingBean)bean).afterPropertiesSet();
                    return null;
                }, this.getAccessControlContext());
            } catch (PrivilegedActionException var6) {
                // Rethrow the exception carried by the privileged-action wrapper.
                throw var6.getException();
            }
        } else {
            ((InitializingBean)bean).afterPropertiesSet();
        }
    }
    // Step 2: the custom init method from the bean definition.
    if (mbd != null && bean.getClass() != NullBean.class) {
        String initMethodName = mbd.getInitMethodName();
        // Skipped when no name is set, when it would repeat afterPropertiesSet() on an
        // InitializingBean, or when the definition reports it as externally managed.
        if (StringUtils.hasLength(initMethodName) && (!isInitializingBean || !"afterPropertiesSet".equals(initMethodName)) && !mbd.isExternallyManagedInitMethod(initMethodName)) {
            this.invokeCustomInitMethod(beanName, bean, mbd);
        }
    }
}
自定义BeanFactoryPostProcessor给bean打标
即给 beanDefinition 打上标记属性
package org.example;

import org.springframework.beans.BeansException;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.stereotype.Component;

import java.util.HashSet;
import java.util.Set;

/**
 * Marks the bean definitions of a configured set of bean names with the attribute
 * {@code Constant.ASYNC_INIT = true}. Names without a bean definition are skipped.
 */
@Component
public class AsyncInitBeanFactoryPostProcessor implements BeanFactoryPostProcessor {

    /** Names of the beans whose definitions get marked. */
    private Set<String> asyncInitBeanNames = new HashSet<>();

    /** Registers the default names {@code "a"} and {@code "b"} and prints the resulting set. */
    public AsyncInitBeanFactoryPostProcessor() {
        // The names could instead come from configuration or any other source.
        asyncInitBeanNames.add("a");
        asyncInitBeanNames.add("b");
        System.out.println("asyncInitBeanNames:" + asyncInitBeanNames);
    }

    /**
     * Sets the {@code Constant.ASYNC_INIT} attribute to {@code true} on every configured bean
     * definition that exists in {@code beanFactory}.
     */
    @Override
    public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
        for (String beanName : asyncInitBeanNames) {
            // Check for existence up front instead of catching NoSuchBeanDefinitionException.
            if (!beanFactory.containsBeanDefinition(beanName)) {
                continue;
            }
            BeanDefinition beanDefinition = beanFactory.getBeanDefinition(beanName);
            beanDefinition.setAttribute(Constant.ASYNC_INIT, true);
        }
    }

    /** @return the live set of configured bean names (not a copy) */
    public Set<String> getAsyncInitBeanNames() {
        return asyncInitBeanNames;
    }

    /** Replaces the set of configured bean names; the given set is stored as-is. */
    public void setAsyncInitBeanNames(Set<String> asyncInitBeanNames) {
        this.asyncInitBeanNames = asyncInitBeanNames;
    }
}
要异步初始化的bean 例子如下
package org.example;

import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service;

import java.util.concurrent.TimeUnit;

/**
 * Demo bean with a slow {@code afterPropertiesSet()}: it sleeps for two seconds and then
 * sets {@link #a} to 100.
 */
@Service
public class A implements InitializingBean {

    /** Set to 100 by {@link #afterPropertiesSet()}; 0 until then. */
    public int a;

    /** Simulates slow initialization work, then sets {@link #a} and reports success on stdout. */
    @Override
    public void afterPropertiesSet() {
        try {
            TimeUnit.SECONDS.sleep(2);
        } catch (InterruptedException e) {
            // Do not swallow the interrupt: restore the flag so the executing thread can observe it.
            Thread.currentThread().interrupt();
        }
        a = 100;
        System.out.println("A.init success");
    }

    /** Prints a greeting including the current value of {@link #a}. */
    public void sayHello() {
        System.out.println("A.sayHello:" + a);
    }
}
AnnotationConfigApplicationContext 测试
package org.example;import org.springframework.context.annotation.AnnotationConfigApplicationContext;public class Main {public static void main(String[] args) {
// AnnotationConfigApplicationContext applicationContext = new AnnotationConfigApplicationContext();CustomBeanFactory customBeanFactory = new CustomBeanFactory();AnnotationConfigApplicationContext applicationContext = new AnnotationConfigApplicationContext(customBeanFactory);applicationContext.register(Config.class);long startTime = System.currentTimeMillis();applicationContext.refresh();customBeanFactory.confirmAllAsyncTaskHadSuccessfulInvoked();long cost = System.currentTimeMillis() - startTime;System.out.println(String.format("======= applicationContext refresh cost:%d s", cost / 1000));A a = (A)applicationContext.getBean("a");a.sayHello();B b = (B)applicationContext.getBean("b");b.sayHello();applicationContext.stop();}
}
测试
对比之前,现在启动耗时约等于最慢的那个 bean 的初始化时间,且初始化结果也是正确的。