ASM is a framework for generating and manipulating Java bytecode, and working with it is a good way to deepen your understanding of bytecode instructions.

Java dynamic proxy

Java dynamic proxies are interface-based: a proxy class can only implement interfaces, so we first define a public interface.

We will proxy the following UserService interface: the handler implements the login logic and prints how long the login call took.

package proxy;

public interface UserService {

    boolean login(String username, String password);
}


Proxy.newProxyInstance takes three parameters: the class loader, the array of interfaces the proxy class should implement, and an InvocationHandler, which is the interface we implement to write the proxy logic.

Implement InvocationHandler: the login succeeds only if both the username and the password equal admin, and the handler prints how long the call took.

package proxy;

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;

public class UserServiceInvocationHandler implements InvocationHandler {
    @Override
    public Object invoke(Object proxy, Method method, Object[] args) {
        long start = System.currentTimeMillis();
        boolean result = "admin".equals(args[0]) && "admin".equals(args[1]); // login succeeds only for admin/admin
        System.out.println("invoke:" + proxy.getClass().getSimpleName() + "." + method.getName() + ":" + (System.currentTimeMillis() - start) + "ms");
        return result;
    }
}

Generating proxy classes

package proxy;

import java.lang.reflect.Proxy;

public class App {
    public static void main(String[] args) {
        UserService userServiceProxy = (UserService) Proxy.newProxyInstance(App.class.getClassLoader(), new Class[]{UserService.class}, new UserServiceInvocationHandler());
        System.out.println(userServiceProxy.getClass());
        System.out.println(userServiceProxy.login("admin", "admin"));
        System.out.println(userServiceProxy.login("admin", "admin1"));
    }
}

Running main prints the proxy class, then true for the first login call and false for the second; each call is also logged by the handler.
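The handler above computes the login result itself. More commonly, an InvocationHandler wraps a real implementation and forwards each call with Method.invoke, doing extra work (here, timing) around it. The following is a minimal sketch of that pattern; SimpleUserService, TimingHandler and DelegatingProxyDemo are hypothetical names introduced only for illustration.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;

public class DelegatingProxyDemo {

    public interface UserService {
        boolean login(String username, String password);
    }

    // Hypothetical real implementation that the proxy wraps
    static class SimpleUserService implements UserService {
        @Override
        public boolean login(String username, String password) {
            return "admin".equals(username) && "admin".equals(password);
        }
    }

    // Forwards every call to the target and prints how long it took
    static class TimingHandler implements InvocationHandler {
        private final Object target;

        TimingHandler(Object target) {
            this.target = target;
        }

        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            long start = System.nanoTime();
            try {
                return method.invoke(target, args);
            } catch (InvocationTargetException e) {
                // Rethrow the target's own exception instead of the reflection wrapper
                throw e.getCause();
            } finally {
                System.out.println(method.getName() + " took " + (System.nanoTime() - start) / 1_000 + "us");
            }
        }
    }

    public static UserService createProxy() {
        return (UserService) Proxy.newProxyInstance(
                UserService.class.getClassLoader(),
                new Class<?>[]{UserService.class},
                new TimingHandler(new SimpleUserService()));
    }

    public static void main(String[] args) {
        UserService service = createProxy();
        System.out.println(service.login("admin", "admin"));  // true
        System.out.println(service.login("admin", "admin1")); // false
    }
}
```

Unwrapping InvocationTargetException keeps the proxy transparent: callers see the same exception the real implementation threw.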

Implementation using ASM

First, let's look at what the generated proxy class looks like when decompiled:

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import proxy.ASMProxy;
import proxy.UserService;

public class $Proxy0 extends ASMProxy implements UserService {
    public static Method _UserService_0 = Class.forName("proxy.UserService").getMethod("login", Class.forName("java.lang.String"), Class.forName("java.lang.String"));

    public $Proxy0(InvocationHandler var1) {
        super(var1);
    }

    public boolean login(String var1, String var2) throws Exception {
        return (Boolean) super.h.invoke(this, _UserService_0, new Object[]{var1, var2});
    }
}

Three main points:

  • The InvocationHandler is stored in the parent class ASMProxy (field h).
  • Each interface method to be implemented is held as a java.lang.reflect.Method in a static field.
  • Each implemented interface method simply calls invoke on the InvocationHandler inherited from the parent class.
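To make the three points concrete, here is a hand-written Java class with the same shape as $Proxy0 that compiles and runs without ASM. BaseProxy, HandWrittenProxy and HandWrittenProxyDemo are hypothetical stand-ins for ASMProxy and the generated class. Unlike generated bytecode, Java source cannot throw checked exceptions from a static initializer or from an interface method that does not declare them, so this sketch wraps them.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.UndeclaredThrowableException;

public class HandWrittenProxyDemo {

    public interface UserService {
        boolean login(String username, String password);
    }

    // Stand-in for ASMProxy. Point 1: the parent class stores the handler
    public static class BaseProxy {
        protected final InvocationHandler h;

        public BaseProxy(InvocationHandler h) {
            this.h = h;
        }
    }

    // Stand-in for the generated $Proxy0
    public static class HandWrittenProxy extends BaseProxy implements UserService {
        // Point 2: the interface Method is looked up once and kept in a static field
        public static final Method _UserService_0;

        static {
            try {
                _UserService_0 = UserService.class.getMethod("login", String.class, String.class);
            } catch (NoSuchMethodException e) {
                throw new ExceptionInInitializerError(e);
            }
        }

        public HandWrittenProxy(InvocationHandler h) {
            super(h);
        }

        // Point 3: the interface method just forwards to the inherited handler
        @Override
        public boolean login(String username, String password) {
            try {
                return (Boolean) h.invoke(this, _UserService_0, new Object[]{username, password});
            } catch (RuntimeException | Error e) {
                throw e;
            } catch (Throwable t) {
                throw new UndeclaredThrowableException(t);
            }
        }
    }

    public static void main(String[] args) {
        InvocationHandler handler = (proxy, method, params) ->
                "admin".equals(params[0]) && "admin".equals(params[1]);
        UserService service = new HandWrittenProxy(handler);
        System.out.println(service.login("admin", "admin"));  // true
        System.out.println(service.login("admin", "admin1")); // false
    }
}
```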

Now let's walk through the implementation step by step.

ASMProxy

package proxy;

import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationHandler;
import java.util.concurrent.atomic.AtomicInteger;

public class ASMProxy {
    protected InvocationHandler h;
    // Proxy class name counter
    private static final AtomicInteger PROXY_CNT = new AtomicInteger(0);
    private static final String PROXY_CLASS_NAME_PRE = "$Proxy";

    public ASMProxy(InvocationHandler var1) {
        h = var1;
    }

    public static Object newProxyInstance(ClassLoader loader, Class<?>[] interfaces, InvocationHandler h)
            throws Exception {
        // Generate the proxy Class
        Class<?> proxyClass = generate(interfaces);
        // Instantiate it through its (InvocationHandler) constructor
        Constructor<?> constructor = proxyClass.getConstructor(InvocationHandler.class);
        return constructor.newInstance(h);
    }

    /**
     * Generate the proxy Class
     *
     * @param interfaces the interfaces the proxy class implements
     * @return the loaded proxy Class
     */
    private static Class<?> generate(Class<?>[] interfaces) throws ClassNotFoundException {
        String proxyClassName = PROXY_CLASS_NAME_PRE + PROXY_CNT.getAndIncrement();
        byte[] codes = ASMProxyFactory.generateClass(interfaces, proxyClassName);

        // Use a custom class loader to load the bytecode.
        // Note: the loader argument is not used; ASMClassLoader's parent is the system class loader.
        ASMClassLoader asmClassLoader = new ASMClassLoader();
        asmClassLoader.add(proxyClassName, codes);
        return asmClassLoader.loadClass(proxyClassName);
    }
}

ASMProxy serves two purposes: it is the parent class that every generated proxy class extends (it holds the InvocationHandler), and it provides a static newProxyInstance method that mirrors Proxy.newProxyInstance. newProxyInstance calls ASMProxyFactory to generate the bytecode, loads it into a Class with the custom class loader, and finally uses reflection to call the (InvocationHandler) constructor and return the proxy instance.
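The last two steps of newProxyInstance, naming the class with a counter and instantiating it through its (InvocationHandler) constructor, use only the standard reflection API. The sketch below applies them to an ordinary compiled class; StandInProxy and ReflectiveInstantiationDemo are hypothetical placeholders for a class ASM would generate, and the generation and loading steps are deliberately left out.

```java
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationHandler;
import java.util.concurrent.atomic.AtomicInteger;

public class ReflectiveInstantiationDemo {

    private static final AtomicInteger PROXY_CNT = new AtomicInteger(0);
    private static final String PROXY_CLASS_NAME_PRE = "$Proxy";

    // Hypothetical stand-in for a generated proxy class with an (InvocationHandler) constructor
    public static class StandInProxy {
        public final InvocationHandler h;

        public StandInProxy(InvocationHandler h) {
            this.h = h;
        }
    }

    // Produces $Proxy0, $Proxy1, ... the same way ASMProxy names its classes
    public static String nextProxyClassName() {
        return PROXY_CLASS_NAME_PRE + PROXY_CNT.getAndIncrement();
    }

    // Mirrors the reflection step at the end of ASMProxy.newProxyInstance
    public static Object instantiate(Class<?> proxyClass, InvocationHandler h) throws Exception {
        Constructor<?> constructor = proxyClass.getConstructor(InvocationHandler.class);
        return constructor.newInstance(h);
    }

    public static void main(String[] args) throws Exception {
        InvocationHandler h = (proxy, method, params) -> null;
        StandInProxy instance = (StandInProxy) instantiate(StandInProxy.class, h);
        System.out.println(instance.h == h);      // true
        System.out.println(nextProxyClassName()); // $Proxy0
        System.out.println(nextProxyClassName()); // $Proxy1
    }
}
```

getConstructor only finds public constructors, which is why the generated class declares its constructor ACC_PUBLIC.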

ASMProxyFactory

Now for the core of the implementation: ASMProxyFactory generates the bytecode in four steps:

  1. Create the constructor (<init>)
  2. Declare a static Method field for each interface method
  3. Create the static initializer (<clinit>) that assigns those fields
  4. Implement the interface methods

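The ASMProxyFactory code passes JVM descriptor strings such as (Ljava/lang/reflect/InvocationHandler;)V to ASM. A method descriptor lists the parameter types in parentheses with no separators or spaces, followed by the return type: primitives and void are single letters, reference types are L followed by the internal name and a semicolon, and arrays are prefixed with [. ASM's Type.getDescriptor and Type.getMethodDescriptor build these strings; the plain-Java sketch below (the Descriptors class is hypothetical) reproduces the rules to show what they contain.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;

public class Descriptors {

    // One-letter codes for primitive types and void
    private static final Map<Class<?>, String> PRIMITIVES = new HashMap<>();

    static {
        PRIMITIVES.put(void.class, "V");
        PRIMITIVES.put(boolean.class, "Z");
        PRIMITIVES.put(byte.class, "B");
        PRIMITIVES.put(char.class, "C");
        PRIMITIVES.put(short.class, "S");
        PRIMITIVES.put(int.class, "I");
        PRIMITIVES.put(long.class, "J");
        PRIMITIVES.put(float.class, "F");
        PRIMITIVES.put(double.class, "D");
    }

    // Field descriptor of one type, e.g. String -> Ljava/lang/String;
    public static String fieldDescriptor(Class<?> type) {
        if (type.isPrimitive()) {
            return PRIMITIVES.get(type);
        }
        if (type.isArray()) {
            // arrays: '[' followed by the component type's descriptor
            return "[" + fieldDescriptor(type.getComponentType());
        }
        // reference types: L + internal name (dots replaced by slashes) + ;
        return "L" + type.getName().replace('.', '/') + ";";
    }

    // Method descriptor: parameter descriptors in parentheses, no separators, then the return type
    public static String methodDescriptor(Class<?> returnType, Class<?>... parameterTypes) {
        StringBuilder sb = new StringBuilder("(");
        for (Class<?> p : parameterTypes) {
            sb.append(fieldDescriptor(p));
        }
        return sb.append(')').append(fieldDescriptor(returnType)).toString();
    }

    public static String methodDescriptor(Method method) {
        return methodDescriptor(method.getReturnType(), method.getParameterTypes());
    }

    public static void main(String[] args) throws NoSuchMethodException {
        // (Ljava/lang/reflect/InvocationHandler;)V, the generated constructor
        System.out.println(methodDescriptor(void.class, InvocationHandler.class));
        // (Ljava/lang/Object;Ljava/lang/reflect/Method;[Ljava/lang/Object;)Ljava/lang/Object;
        System.out.println(methodDescriptor(InvocationHandler.class.getMethod("invoke", Object.class, Method.class, Object[].class)));
    }
}
```

A stray space inside a descriptor, for example "(Ljava/lang/String;) V", makes it invalid, so these strings must be written exactly.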
package proxy;

import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.stream.Collectors;

public class ASMProxyFactory {
    private static final Integer DEFAULT_NUM = 1;

    public static byte[] generateClass(Class<?>[] interfaces, String proxyClassName) {
        // Create a ClassWriter that automatically computes stack map frames and the max stack and local variable sizes
        ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS);
        // Create the Java version, access flag, class name, parent class, interface
        cw.visit(Opcodes.V1_8, Opcodes.ACC_PUBLIC, proxyClassName, null, Type.getInternalName(ASMProxy.class), getInterfacesName(interfaces));
        // create <init>
        createInit(cw);
        // create the static Method fields
        addStatic(cw, interfaces);
        // create <clinit>
        addClinit(cw, interfaces, proxyClassName);
        // Implement interface methods
        addInterfacesImpl(cw, interfaces, proxyClassName);
        cw.visitEnd();
        return cw.toByteArray();
    }
    
    private static String[] getInterfacesName(Class<?>[] interfaces) {
        String[] interfacesName = new String[interfaces.length];
        return Arrays.stream(interfaces).map(Type::getInternalName).collect(Collectors.toList()).toArray(interfacesName);
    }
    
    /**
     * Create the constructor, equivalent to:
     * 0 aload_0
     * 1 aload_1
     * 2 invokespecial the ASMProxy(InvocationHandler) constructor, descriptor (Ljava/lang/reflect/InvocationHandler;)V
     * 5 return
     *
     * @param cw
     */
    private static void createInit(ClassWriter cw) {
        MethodVisitor mv = cw.visitMethod(Opcodes.ACC_PUBLIC, "<init>", "(Ljava/lang/reflect/InvocationHandler;)V", null, null);
        mv.visitCode();
        // push this onto the stack
        mv.visitVarInsn(Opcodes.ALOAD, 0);
        // push parameters to the stack
        mv.visitVarInsn(Opcodes.ALOAD, 1);
        // Call the parent class initialization method
        mv.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(ASMProxy.class), "<init>", "(Ljava/lang/reflect/InvocationHandler;)V", false);
        // return
        mv.visitInsn(Opcodes.RETURN);
        mv.visitMaxs(2, 2);
        mv.visitEnd();
    }
    
    /**
     * Create one static Method field per interface method
     *
     * @param cw
     * @param interfaces
     */
    private static void addStatic(ClassWriter cw, Class<?>[] interfaces) {
        for (Class<?> anInterface : interfaces) {
            for (int i = 0; i < anInterface.getMethods().length; i++) {
                String methodName = "_" + anInterface.getSimpleName() + "_" + i;
                cw.visitField(Opcodes.ACC_PUBLIC | Opcodes.ACC_STATIC, methodName, Type.getDescriptor(Method.class), null, null);
            }
        }
    }

    /**
     * Create the static initializer that assigns each field, e.g.
     * _UserService_0 = Class.forName("proxy.UserService").getMethod("login", String.class, String.class);
     */
    private static void addClinit(ClassWriter cw, Class<?>[] interfaces, String proxyClassName) {
        MethodVisitor mv = cw.visitMethod(Opcodes.ACC_STATIC, "<clinit>", "()V", null, null);
        mv.visitCode();
        for (Class<?> anInterface : interfaces) {
            for (int i = 0; i < anInterface.getMethods().length; i++) {
                Method method = anInterface.getMethods()[i];
                String methodName = "_" + anInterface.getSimpleName() + "_" + i;
                mv.visitLdcInsn(anInterface.getName());
                mv.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Class.class), "forName", "(Ljava/lang/String;)Ljava/lang/Class;", false);
                mv.visitLdcInsn(method.getName());
                if (method.getParameterCount() == 0) {
                    mv.visitInsn(Opcodes.ACONST_NULL);
                } else {
                    switch (method.getParameterCount()) {

                        case 1:
                            mv.visitInsn(Opcodes.ICONST_1);
                            break;
                        case 2:
                            mv.visitInsn(Opcodes.ICONST_2);
                            break;
                        case 3:
                            mv.visitInsn(Opcodes.ICONST_3);
                            break;
                        default:
                            mv.visitIntInsn(Opcodes.BIPUSH, method.getParameterCount());
                            break;
                    }
                    mv.visitTypeInsn(Opcodes.ANEWARRAY, Type.getInternalName(Class.class));
                    for (int paramIndex = 0; paramIndex < method.getParameterTypes().length; paramIndex++) {
                        Class<?> parameter = method.getParameterTypes()[paramIndex];
                        // duplicate the array reference, then push the index
                        mv.visitInsn(Opcodes.DUP);
                        switch (paramIndex) {
                            case 0:
                                mv.visitInsn(Opcodes.ICONST_0);
                                break;
                            case 1:
                                mv.visitInsn(Opcodes.ICONST_1);
                                break;
                            case 2:
                                mv.visitInsn(Opcodes.ICONST_2);
                                break;
                            case 3:
                                mv.visitInsn(Opcodes.ICONST_3);
                                break;
                            default:
                                mv.visitIntInsn(Opcodes.BIPUSH, paramIndex);
                                break;
                        }
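                        // Class.forName(name) resolves reference and array types only;
                        // a primitive parameter such as int would throw ClassNotFoundException
                        mv.visitLdcInsn(parameter.getName());
                        mv.visitMethodInsn(
                                Opcodes.INVOKESTATIC, Type.getInternalName(Class.class),
                                "forName", "(Ljava/lang/String;)Ljava/lang/Class;", false);
                        // store the Class into the array
                        mv.visitInsn(Opcodes.AASTORE);
                    }
                }
                // invokevirtual Class.getMethod(String, Class[])
                mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Class.class), "getMethod", "(Ljava/lang/String;[Ljava/lang/Class;)Ljava/lang/reflect/Method;", false);
                // putstatic: store the Method in its static field
                mv.visitFieldInsn(Opcodes.PUTSTATIC, proxyClassName, methodName, Type.getDescriptor(Method.class));
            }
        }
        // return once, after the fields of all interfaces have been assigned
        mv.visitInsn(Opcodes.RETURN);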
        mv.visitMaxs(DEFAULT_NUM, DEFAULT_NUM);
        mv.visitEnd();
    }
    
    private static void addInterfacesImpl(ClassWriter cw, Class<?>[] interfaces, String proxyClassName) {
        for (Class<?> anInterface : interfaces) {
            for (int i = 0; i < anInterface.getMethods().length; i++) {
                Method method = anInterface.getMethods()[i];
                String methodName = "_" + anInterface.getSimpleName() + "_" + i;
                MethodVisitor mv = cw.visitMethod(Opcodes.ACC_PUBLIC, method.getName(), Type.getMethodDescriptor(method), null, new String[]{Type.getInternalName(Exception.class)});
                mv.visitCode();
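                // push this, then load the inherited handler field h
                mv.visitVarInsn(Opcodes.ALOAD, 0);
                mv.visitFieldInsn(Opcodes.GETFIELD, Type.getInternalName(ASMProxy.class), "h", "Ljava/lang/reflect/InvocationHandler;");
                // first argument of invoke: the proxy itself
                mv.visitVarInsn(Opcodes.ALOAD, 0);
                // second argument: the static Method field
                mv.visitFieldInsn(Opcodes.GETSTATIC, proxyClassName, methodName, Type.getDescriptor(Method.class));
                // third argument: an Object[] holding the call's arguments; push its length first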
                switch (method.getParameterCount()) {
                    case 0:
                        mv.visitInsn(Opcodes.ICONST_0);
                        break;
                    case 1:
                        mv.visitInsn(Opcodes.ICONST_1);
                        break;
                    case 2:
                        mv.visitInsn(Opcodes.ICONST_2);
                        break;
                    case 3:
                        mv.visitInsn(Opcodes.ICONST_3);
                        break;
                    default:
                        mv.visitIntInsn(Opcodes.BIPUSH, method.getParameterCount());
                        break;
                }
                mv.visitTypeInsn(Opcodes.ANEWARRAY, Type.getInternalName(Object.class));
                // for each parameter: dup, push the index, load the argument, aastore
                for (int paramIndex = 0; paramIndex < method.getParameterCount(); paramIndex++) {
                    mv.visitInsn(Opcodes.DUP);
                    switch (paramIndex) {
                        case 0:
                            mv.visitInsn(Opcodes.ICONST_0);
                            break;
                        case 1:
                            mv.visitInsn(Opcodes.ICONST_1);
                            break;
                        case 2:
                            mv.visitInsn(Opcodes.ICONST_2);
                            break;
                        case 3:
                            mv.visitInsn(Opcodes.ICONST_3);
                            break;
                        default:
                            mv.visitIntInsn(Opcodes.BIPUSH, paramIndex);
                            break;
                    }
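                    // ALOAD from slot paramIndex + 1 assumes every parameter is a reference type;
                    // primitives would need xLOAD plus boxing, and long/double occupy two slots
                    mv.visitVarInsn(Opcodes.ALOAD, paramIndex + 1);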
                    mv.visitInsn(Opcodes.AASTORE);
                }
                // invokeinterface InvocationHandler.invoke(Object, Method, Object[])
                mv.visitMethodInsn(Opcodes.INVOKEINTERFACE, Type.getInternalName(InvocationHandler.class), "invoke", "(Ljava/lang/Object;Ljava/lang/reflect/Method;[Ljava/lang/Object;)Ljava/lang/Object;", true);
                // convert the Object result to the method's return type and return it
                addReturn(mv, method.getReturnType());
                // COMPUTE_MAXS is set, so the visitMaxs arguments are ignored
                mv.visitMaxs(DEFAULT_NUM, DEFAULT_NUM);
                mv.visitEnd();
            }
        }
    }

    // Add method returns
    private static void addReturn(MethodVisitor mv, Class<?> returnType) {
        if (returnType == void.class) {
            // discard the Object returned by invoke
            mv.visitInsn(Opcodes.POP);
            mv.visitInsn(Opcodes.RETURN);
            return;
        }
        if (returnType == boolean.class) {
            // checkcast java/lang/Boolean, then invokevirtual Boolean.booleanValue()
            mv.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Boolean.class));
            mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Boolean.class), "booleanValue", "()Z", false);
            mv.visitInsn(Opcodes.IRETURN);
        } else if (returnType == int.class) {
            mv.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Integer.class));
            mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Integer.class), "intValue", "()I", false);
            mv.visitInsn(Opcodes.IRETURN);
        } else if (returnType == long.class) {
            mv.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Long.class));
            mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Long.class), "longValue", "()J", false);
            mv.visitInsn(Opcodes.LRETURN);
        } else if (returnType == short.class) {
            mv.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Short.class));
            mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Short.class), "shortValue", "()S", false);
            mv.visitInsn(Opcodes.IRETURN);
        } else if (returnType == byte.class) {
            mv.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Byte.class));
            mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Byte.class), "byteValue", "()B", false);
            mv.visitInsn(Opcodes.IRETURN);
        } else if (returnType == char.class) {
            mv.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Character.class));
            mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Character.class), "charValue", "()C", false);
            mv.visitInsn(Opcodes.IRETURN);
        } else if (returnType == float.class) {
            mv.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Float.class));
            mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Float.class), "floatValue", "()F", false);
            mv.visitInsn(Opcodes.FRETURN);
        } else if (returnType == double.class) {
            mv.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Double.class));
            mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Double.class), "doubleValue", "()D", false);
            mv.visitInsn(Opcodes.DRETURN);
        } else {
            // reference return type: a plain checkcast is enough
            mv.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(returnType));
            mv.visitInsn(Opcodes.ARETURN);
        }
    }
}
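addReturn, like any bytecode that moves values between locals and the operand stack, has to pick instructions by type: long, float and double have their own load and return instructions, boolean, byte, char, short and int all share the int forms, and long and double occupy two local-variable slots. The plain-Java sketch below (the TypeInstructions class is hypothetical and does not use ASM) tabulates those rules. It also shows why the generated interface methods, which load argument i with ALOAD i + 1, only work when every parameter is a reference type.

```java
import java.util.Arrays;

public class TypeInstructions {

    // Return instruction for a method returning the given type
    public static String returnInsn(Class<?> type) {
        if (type == void.class) return "return";
        if (type == long.class) return "lreturn";
        if (type == float.class) return "freturn";
        if (type == double.class) return "dreturn";
        if (type.isPrimitive()) return "ireturn"; // boolean, byte, char, short, int
        return "areturn";
    }

    // Load instruction for a local variable (parameter) of the given type
    public static String loadInsn(Class<?> type) {
        if (type == long.class) return "lload";
        if (type == float.class) return "fload";
        if (type == double.class) return "dload";
        if (type.isPrimitive()) return "iload"; // boolean, byte, char, short, int
        return "aload";
    }

    // long and double take two local-variable slots, everything else one
    public static int slotSize(Class<?> type) {
        return (type == long.class || type == double.class) ? 2 : 1;
    }

    // Local-variable slot of each parameter of an instance method (slot 0 holds this)
    public static int[] parameterSlots(Class<?>... parameterTypes) {
        int[] slots = new int[parameterTypes.length];
        int next = 1;
        for (int i = 0; i < parameterTypes.length; i++) {
            slots[i] = next;
            next += slotSize(parameterTypes[i]);
        }
        return slots;
    }

    public static void main(String[] args) {
        System.out.println(returnInsn(boolean.class)); // ireturn
        System.out.println(returnInsn(long.class));    // lreturn
        System.out.println(loadInsn(String.class));    // aload
        // (String, long, int): the int sits in slot 4, not slot 3
        System.out.println(Arrays.toString(parameterSlots(String.class, long.class, int.class))); // [1, 2, 4]
    }
}
```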

ASMClassLoader

A custom class loader. Its add method registers a <class name, bytecode> mapping, and its overridden findClass method calls defineClass to turn the bytecode into a Class whenever the requested name has a registered mapping.

package proxy;

import java.util.HashMap;
import java.util.Map;

public class ASMClassLoader extends ClassLoader {
    private final Map<String, byte[]> classMap = new HashMap<>();

    @Override
    protected Class<?> findClass(String name) throws ClassNotFoundException {
        if (classMap.containsKey(name)) {
            byte[] bytes = classMap.get(name);
            classMap.remove(name);
            return defineClass(name, bytes, 0, bytes.length);
        }
        return super.findClass(name);
    }

    public void add(String name, byte[] bytes) {
        classMap.put(name, bytes);
    }
}
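ASMClassLoader only needs to override findClass because ClassLoader.loadClass first delegates to the parent loader and calls findClass only when the parent cannot find the class. Since a generated name like $Proxy0 exists nowhere on the classpath, the request always falls through to findClass. The sketch below (RecordingClassLoader is a hypothetical name) records which names reach findClass to make that delegation order visible.

```java
import java.util.ArrayList;
import java.util.List;

public class RecordingClassLoader extends ClassLoader {

    // Names the parent could not resolve, so they reached findClass
    private final List<String> requested = new ArrayList<>();

    @Override
    protected Class<?> findClass(String name) throws ClassNotFoundException {
        requested.add(name);
        // No bytecode registered here, so fall back to the default, which throws
        return super.findClass(name);
    }

    public List<String> getRequested() {
        return requested;
    }

    public static void main(String[] args) {
        RecordingClassLoader loader = new RecordingClassLoader();
        try {
            // Found by the parent: findClass is never called
            System.out.println(loader.loadClass("java.lang.String") == String.class); // true
            // Not found by the parent: falls through to findClass
            loader.loadClass("$Proxy999");
        } catch (ClassNotFoundException e) {
            System.out.println("not found: " + e.getMessage());
        }
        System.out.println(loader.getRequested()); // [$Proxy999]
    }
}
```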

App

package proxy;

import java.lang.reflect.Proxy;

public class App {
    public static void main(String[] args) throws Throwable {
        System.out.println("Java dynamic Proxy ===========================");
        UserService userServiceProxy = (UserService) Proxy.newProxyInstance(App.class.getClassLoader(), new Class[]{UserService.class}, new UserServiceInvocationHandler());
        System.out.println(userServiceProxy.getClass());
        System.out.println(userServiceProxy.login("admin", "admin"));
        System.out.println(userServiceProxy.login("admin", "admin1"));


        System.out.println("ASM Dynamic Proxy ===========================");
        UserService userServiceAsm = (UserService) ASMProxy.newProxyInstance(App.class.getClassLoader(), new Class[]{UserService.class}, new UserServiceInvocationHandler());
        System.out.println(userServiceAsm.getClass());
        System.out.println(userServiceAsm.login("admin", "admin"));
        System.out.println(userServiceAsm.login("admin", "admin1"));
    }
}

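Running App prints, for both the JDK proxy and the ASM proxy, the proxy class followed by true for admin/admin and false for admin/admin1; each login call is also logged by the handler.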