Java 在套接字连接中使用交换密钥的 RSA

Java RSA with exchanged keys in socket connection

我正在用 Java 实现一个使用 RSA 加密的套接字连接。为此我使用了两个类:第一个是抽象类 Connection,它代表实际的套接字连接;第二个是 ConnectionCallback,它是 Connection 类接收到数据时调用的类。当 Connection 类接收到数据时,会使用事先从连接端点共享来的公钥对数据进行解密(只能有 1 个连接端点)。

ByteArray 类:

package connection.data;

public class ByteArray {

    /**
     * Current contents. This is {@code null} for an instance built with the no-arg
     * constructor until the first {@link #add(byte[])} call.
     */
    private byte[] bytes;

    /** Wraps {@code bytes} directly; the array is not copied. */
    public ByteArray(byte[] bytes){
        this.bytes = bytes;
    }

    /** Creates an instance whose contents stay {@code null} until data is added. */
    public ByteArray(){
    }

    /**
     * Appends {@code data} after the current contents and replaces the backing array with a
     * freshly allocated one.
     *
     * @param data bytes to append; passing {@code null} results in a NullPointerException
     */
    public void add(byte[] data) {
        if (bytes == null) {
            bytes = new byte[0];
        }
        bytes = concat(bytes, data);
    }

    /** Returns a new array holding the bytes of {@code head} followed by those of {@code tail}. */
    private byte[] concat(byte[] head, byte[] tail) {
        final int split = head.length;
        byte[] out = new byte[split + tail.length];
        System.arraycopy(head, 0, out, 0, split);
        System.arraycopy(tail, 0, out, split, tail.length);
        return out;
    }

    /** Returns the backing array itself (not a copy), or {@code null} if nothing is stored. */
    public byte[] getBytes(){
        return bytes;
    }
}

Connection 类:

package connection;

import connection.data.ByteArray;
import connection.protocols.ProtectedConnectionProtocol;
import crypto.CryptoUtils;
import crypto.algorithm.asymmetric.rsa.RSAAlgorithm;
import protocol.connection.ConnectionProtocol;
import util.function.Callback;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.InputStream;
import java.io.OutputStream;
import java.security.PublicKey;
import java.util.Base64;

/**
 * Base class for a single socket connection whose incoming data is handled by a {@link Callback}.
 *
 * <p>{@link #openConnection} wraps the given streams and starts a background thread running this
 * {@link Runnable}. It then sends the local public key, Base64-encoded, as the first outgoing
 * message.
 */
public abstract class Connection implements Runnable {

    /** Size of the scratch buffer used for each individual socket read. */
    private static final int READ_BUFFER_SIZE = 8192;

    private DataInputStream in;
    private DataOutputStream out;
    ConnectionProtocol protocol;
    private Callback callback;
    private boolean isConnected = false;

    /**
     * Creates a connection with a {@link ProtectedConnectionProtocol} and a
     * {@link ConnectionCallback} bound to this instance. The protocol is backed by an
     * {@code RSAAlgorithm} created with key length 1024.
     *
     * @throws Exception if the protocol or its algorithm cannot be created
     */
    public Connection() throws Exception {
        this.protocol = new ProtectedConnectionProtocol(new RSAAlgorithm(1024));
        this.callback = new ConnectionCallback(this);
    }

    /**
     * Creates a connection with an explicit protocol and callback.
     *
     * @param connectionProtocol protocol providing keys and encryption
     * @param callback invoked with a {@link ByteArray} for each burst of received data
     * @throws Exception declared for subclasses; this constructor itself does not throw
     */
    public Connection(ConnectionProtocol connectionProtocol, Callback callback) throws Exception {
        this.protocol = connectionProtocol;
        this.callback = callback;
    }

    /**
     * Reader loop. Each iteration runs these steps:
     * <ol>
     *   <li>block until at least one byte arrives;</li>
     *   <li>drain whatever else is immediately available;</li>
     *   <li>pass the collected bytes to the callback.</li>
     * </ol>
     * The loop stops at end of stream or on the first exception, which is printed.
     *
     * <p>NOTE(review): TCP is a byte stream, so one "burst" is not guaranteed to match one message
     * written by the peer. Reliable message boundaries would need explicit framing, for example a
     * length prefix.
     */
    @Override
    public void run() {
        while (isConnected) {
            try {
                // Blocking read: wait for the peer instead of busy-polling available().
                byte[] first = this.read();
                if (first == null) {
                    isConnected = false; // peer closed the stream
                    break;
                }
                ByteArray data = new ByteArray(first);
                while (this.in.available() > 0) {
                    byte[] chunk = this.read();
                    if (chunk == null) {
                        isConnected = false; // stream ended mid-burst; deliver what we have
                        break;
                    }
                    data.add(chunk);
                }
                callback.run(data);
            } catch (Exception e) {
                e.printStackTrace();
                break;
            }
        }
    }

    /**
     * Wraps the streams, starts the reader thread, and sends the local public key to the peer.
     *
     * @param in stream to read incoming data from
     * @param out stream to write outgoing data to
     * @throws Exception if writing the public key fails
     */
    protected void openConnection(InputStream in, OutputStream out) throws Exception{
        this.in = new DataInputStream(in);
        this.out = new DataOutputStream(out);
        this.isConnected = true;
        new Thread(this).start();
        this.write(CryptoUtils.encode(((PublicKey) this.protocol.getPublicKey()).getEncoded()));
    }

    /** Writes {@code data} to the peer and flushes; also echoes it to stdout as UTF-8 text. */
    private void write(byte[] data) throws Exception{
        System.out.println(new String(data,"UTF-8"));
        this.out.write(data);
        this.out.flush();
    }

    /**
     * Performs one blocking read of up to {@link #READ_BUFFER_SIZE} bytes.
     *
     * @return exactly the bytes that were read, never padded with trailing zeros, or {@code null}
     *     at end of stream
     */
    private byte[] read() throws Exception{
        byte[] buffer = new byte[READ_BUFFER_SIZE];
        int count = this.in.read(buffer);
        if (count < 0) {
            return null;
        }
        // Copy only the filled prefix. Returning the whole buffer would append zero bytes, which
        // the Base64 decoder rejects with "Illegal base64 character 0".
        byte[] received = new byte[count];
        System.arraycopy(buffer, 0, received, 0, count);
        return received;
    }

}

ConnectionCallback 类:

package connection;

import connection.data.ByteArray;
import crypto.CryptoUtils;
import util.function.Callback;
import java.security.KeyFactory;
import java.security.PublicKey;
import java.security.spec.X509EncodedKeySpec;

/**
 * Handles data received by a {@link Connection}.
 *
 * <p>Every payload is Base64-decoded first. While the connection's protocol has no shared key
 * yet, the decoded bytes are parsed as an X.509-encoded RSA public key and stored as the shared
 * key. Once a shared key is set, payloads are currently ignored.
 */
public class ConnectionCallback implements Callback {

    private Connection connection;

    public ConnectionCallback(Connection connection){
        this.connection = connection;
    }

    /**
     * Processes one burst of received bytes.
     *
     * @param data expected to be a {@link ByteArray}
     * @throws Exception if decoding or key parsing fails
     */
    @Override
    public void run(Object data) throws Exception {
        byte[] decoded = CryptoUtils.decode(((ByteArray) data).getBytes());
        boolean awaitingPeerKey = this.connection.protocol.getSharedKey() == null;
        if (!awaitingPeerKey) {
            // TODO: handle application payloads (previously sketched as StrongboxObject.parse(...)).
            return;
        }
        X509EncodedKeySpec peerKeySpec = new X509EncodedKeySpec(decoded);
        PublicKey peerKey = KeyFactory.getInstance("RSA").generatePublic(peerKeySpec);
        this.connection.protocol.setSharedKey(peerKey);
    }

}

RSAAlgorithm 类:

package crypto.algorithm.asymmetric.rsa;

import crypto.CryptoUtils;
import crypto.algorithm.asymmetric.AssimetricalAlgorithm;
import javax.crypto.Cipher;
import java.security.*;
import java.util.Base64;

/**
 * RSA implementation of {@link AssimetricalAlgorithm}.
 *
 * <p>Outgoing messages are encrypted with the shared key (the peer's public key) and then
 * Base64-encoded. Incoming messages are Base64-decoded and decrypted with the local private key.
 *
 * <p>NOTE(review): the padding and transformation come from the inherited {@code cipher} field,
 * which is not visible here. Plain RSA can only encrypt inputs smaller than the key modulus;
 * larger payloads presumably need a hybrid (RSA + symmetric) scheme.
 */
public class RSAAlgorithm extends AssimetricalAlgorithm {

    private KeyPairGenerator keyGen;

    /**
     * Creates the algorithm and immediately generates a key pair.
     *
     * @param keyLength RSA modulus size in bits, passed to {@link KeyPairGenerator#initialize(int)}
     * @throws Exception if the RSA generator is unavailable or the key length is rejected
     */
    public RSAAlgorithm(int keyLength) throws Exception {
        super();
        keyGen = KeyPairGenerator.getInstance("RSA");
        keyGen.initialize(keyLength);
        this.generateKeys();
    }

    /** Generates a fresh key pair and stores both halves via the superclass setters. */
    @Override
    public void generateKeys() {
        KeyPair fresh = keyGen.generateKeyPair();
        super.setPublicKey(fresh.getPublic());
        super.setPrivateKey(fresh.getPrivate());
    }

    /**
     * Encrypts {@code message} with the shared public key and Base64-encodes the result.
     *
     * @return Base64 ciphertext, or an empty array if anything fails (the error is printed)
     */
    @Override
    public byte[] encrypt(byte[] message) {
        byte[] result = new byte[0];
        try {
            PublicKey peerKey = (PublicKey) super.getSharedKey();
            super.cipher.init(Cipher.ENCRYPT_MODE, peerKey);
            byte[] rawCiphertext = super.cipher.doFinal(message);
            result = CryptoUtils.encode(rawCiphertext);
        } catch (Exception e) {
            e.printStackTrace();
        }
        return result;
    }

    /**
     * Base64-decodes {@code message} and decrypts it with the local private key.
     *
     * <p>Base64 decoding happens outside the try block, so invalid Base64 propagates as an
     * {@link IllegalArgumentException}.
     *
     * @return the plaintext, or an empty array if decryption fails (the error is printed)
     */
    @Override
    public byte[] decrypt(byte[] message) {
        byte[] ciphertext = CryptoUtils.decode(message);
        try {
            PrivateKey ownKey = (PrivateKey) super.getPrivateKey();
            super.cipher.init(Cipher.DECRYPT_MODE, ownKey);
            return super.cipher.doFinal(ciphertext);
        } catch (Exception e) {
            e.printStackTrace();
        }
        return new byte[0];
    }

}

ProtectedConnectionProtocol 类:

package connection.protocols;

import protocol.connection.ConnectionProtocol;
import crypto.algorithm.asymmetric.AssimetricalAlgorithm;

/**
 * {@link ConnectionProtocol} that forwards all key management and encryption calls to an
 * {@link AssimetricalAlgorithm}.
 */
public class ProtectedConnectionProtocol extends ConnectionProtocol {

    /** Algorithm every call is delegated to. */
    private final AssimetricalAlgorithm algorithm;

    /** @param algorithm the algorithm that supplies keys and performs encryption/decryption */
    public ProtectedConnectionProtocol(AssimetricalAlgorithm algorithm){
        this.algorithm = algorithm;
    }

    /** Returns the algorithm's local public key. */
    @Override
    public Object getPublicKey() {
        return algorithm.getPublicKey();
    }

    /** Returns the algorithm's local private key. */
    @Override
    public Object getPrivateKey() {
        return algorithm.getPrivateKey();
    }

    /** Returns the key received from the peer, or whatever the algorithm holds if none was set. */
    @Override
    public Object getSharedKey() {
        return algorithm.getSharedKey();
    }

    /** Stores the peer's key in the algorithm. */
    @Override
    public void setSharedKey(Object sharedKey){
        algorithm.setSharedKey(sharedKey);
    }

    /** Delegates decryption of {@code message} to the algorithm. */
    @Override
    public byte[] decrypt(byte[] message) {
        return algorithm.decrypt(message);
    }

    /** Delegates encryption of {@code message} to the algorithm. */
    @Override
    public byte[] encrypt(byte[] message) {
        return algorithm.encrypt(message);
    }

}

CryptoUtils 类:

package crypto;

import java.util.Base64;

/** Base64 helpers (basic RFC 4648 alphabet, no line breaks) used for wire encoding. */
public class CryptoUtils {

    /**
     * Base64-encodes {@code data}.
     *
     * @return a newly allocated array of ASCII Base64 characters
     */
    public static byte[] encode(byte[] data){
        Base64.Encoder encoder = Base64.getEncoder();
        return encoder.encode(data);
    }

    /**
     * Decodes Base64 {@code data}.
     *
     * @throws IllegalArgumentException if {@code data} contains non-Base64 bytes (such as trailing
     *     zero bytes) or is otherwise malformed
     */
    public static byte[] decode(byte[] data){
        Base64.Decoder decoder = Base64.getDecoder();
        return decoder.decode(data);
    }

}

2019 年 5 月 9 日更新
更新代码后仍然抛出相同的异常:

MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCcrbJGHqpJdhDbVoZCJ0bucb8YnvcVWx9HIUfJOgmAKIuTmw1VUCk85ztqDq0VP2k6IP2bSD5MegR10FtqGtGEQrv+m0eNgbvE3O7czUzvedb5wKbA8eiSPbcX8JElobOhrolOb8JQRQzWAschBNp4MDljlu+0KZQHtZa6pPYJ0wIDAQAB
    java.lang.IllegalArgumentException: Illegal base64 character 0
        at java.base/java.util.Base64$Decoder.decode0(Base64.java:743)
        at java.base/java.util.Base64$Decoder.decode(Base64.java:535)
        at crypto.CryptoUtils.decode(CryptoUtils.java:12)
        at connection.ConnectionCallback.run(ConnectionCallback.java:21)
        at connection.Connection.run(Connection.java:42)
        at java.base/java.lang.Thread.run(Thread.java:834)

请帮帮我,这个问题让我很头疼。悬赏只剩 2 天了,我宁愿把悬赏给帮我解决这个问题的人,也不想让它白白过期。

问题很可能出在你读取数据的方式上:

// Quoted original (buggy) version: reads into a fixed 8192-byte buffer.
private byte[] read() throws Exception{
    byte[] bytes = new byte[8192];
    // BUG: the return value (number of bytes actually read) is ignored, so the unused tail of
    // the buffer stays zero and is later passed to the Base64 decoder.
    this.in.read(bytes);
    return bytes;
}

你总是读入一个 8192 字节的数组,即使输入流中并没有那么多字节。this.in.read(bytes) 会返回实际读取的字节数,你应该只使用数组中这么多的字节,忽略其余部分——因为数组的剩余部分全是 0,当你尝试对其进行 base64 解码时,就会得到 java.lang.IllegalArgumentException: Illegal base64 character 0。

因此,读取字节后,只需把实际读到的那部分复制到一个新数组中:

private byte[] read() throws Exception{
    byte[] bytes = new byte[8192];
    int read = this.in.read(bytes);
    if (read <= 0) return new byte[0]; // or return null, or something, read might be -1 when there was no data.
    byte[] readBytes = new byte[read]
    System.arraycopy(bytes, 0, readBytes, 0, read)
    return readBytes;
}

请注意,这种读取方式对性能其实很不友好,因为每次读取都会分配不少内存。像 netty 这样更高级的库有自己的字节缓冲区,带有独立的读/写位置,并把所有内容存放在一个可自动扩容的字节数组中。不过先让程序跑通;如果以后遇到性能问题,请记住这里就是可能需要优化的地方之一。

另外,在你的 ByteArray 中,你把两个数组都复制到了同一个起始位置:

    // First loop: copies the existing bytes into the front of bytes1.
    for(int i = 0; i < this.bytes.length; i++){
        bytes1[i] = this.bytes[i];
    }
    // Second loop: BUG, it should write to bytes1[i + this.bytes.length], not bytes1[i].
    for(int i = 0; i < data.length; i++){
        bytes1[i] = data[i]; // this loop starts from 0 too
    }

在第二个循环中,你需要写入 i + this.bytes.length 的位置(更好的做法是直接使用 System.arraycopy):

/**
 * Returns a new array containing the bytes of {@code array1} followed by those of {@code array2}.
 */
public byte[] joinArrays(byte[] array1, byte[] array2) {
    int offset = array1.length;
    byte[] joined = new byte[offset + array2.length];
    System.arraycopy(array1, 0, joined, 0, offset);
    System.arraycopy(array2, 0, joined, offset, array2.length);
    return joined;
}

然后 add 方法就变成:

/** Appends {@code data} to the stored bytes, starting from an empty array if none exist yet. */
public void add(byte[] data) {
    if (this.bytes == null) {
        this.bytes = new byte[0];
    }
    this.bytes = joinArrays(this.bytes, data);
}

另外,正如其他答案所说,把 flush 方法改成只将字段设为 null 可能是个好主意;更好的做法是直接删掉这个方法,因为我没看到它在哪里被使用,而且你随时可以创建该对象的新实例。

我查看了你的代码,发现问题出在 ByteArray 类的 add() 方法中。具体如下(见代码注释):

原代码:ByteArray

// Quoted original (buggy) version of ByteArray.add.
public void add(byte[] data){
    if(this.bytes == null)
        // BUG: allocates data.length zero bytes instead of starting from an empty array.
        this.bytes = new byte[data.length]; 
    byte[] bytes1 = new byte[this.bytes.length + data.length];
    for(int i = 0; i < this.bytes.length; i++){
        bytes1[i] = this.bytes[i]; // when this.bytes is null you are adding data.length amount of 0, which is probably not what you want. This prevents the base64 decoder from decoding
    }
    // BUG: also starts writing at index 0, overwriting the bytes copied above
    // (should be bytes1[i + this.bytes.length]). The result is data followed by zeros.
    for(int i = 0; i < data.length; i++){
        bytes1[i] = data[i];
    }
    this.bytes = bytes1;
}

解决方案:ByteArray

/**
 * Appends {@code data} to the stored bytes.
 * The first call simply keeps a reference to {@code data} (no copy); later calls copy both
 * arrays into a new one.
 */
public void add(byte[] data){
    if(this.bytes == null) {
        this.bytes = data; // just store it because the field is null
    } else {
        byte[] bytes1 = new byte[this.bytes.length + data.length];
        System.arraycopy(this.bytes, 0, bytes1, 0, this.bytes.length);
        // New data goes after the existing bytes; writing from index 0 would overwrite them.
        System.arraycopy(data, 0, bytes1, this.bytes.length, data.length);
        this.bytes = bytes1;
    }
}

/** Discards the stored bytes so the next add() starts fresh. */
public void flush(){
    this.bytes = null; // Important
}

EDIT

查看 Connection 类中读取字节的代码后,我发现它在末尾读入了多余的 0 字节。因此我想出了以下解决方法:

重构:Connection

...

public abstract class Connection implements Runnable {

...

// Reader loop: collects whatever bytes are currently available and hands them to the callback.
// NOTE(review): when no data is available this loop spins without blocking, which burns CPU;
// a blocking first read() would avoid that.
@Override
public void run() {
    while(isConnected){
        try {
            ByteArray data = new ByteArray();
            while(this.in.available() > 0){
                byte[] read = this.read();
                if (read != null) {
                    data.add(read);
                }
            }
            if(data.getBytes() != null){
                callback.run(data);
            }
        } catch (Exception e){
            e.printStackTrace();
            break;
        }
    }
}

...

/**
 * Reads up to the number of bytes currently reported by available() and returns exactly the
 * bytes that were read, or null if nothing was read.
 */
private byte[] read() throws Exception{
    byte[] bytes = new byte[this.in.available()];
    int read = this.in.read(bytes);
    if (read <= 0) return null; // -1 at end of stream; 0 when nothing was available
    // read() may return fewer bytes than requested; trim so no trailing zero bytes leak through.
    if (read < bytes.length) {
        byte[] trimmed = new byte[read];
        System.arraycopy(bytes, 0, trimmed, 0, read);
        return trimmed;
    }
    return bytes;
}

}