diff --git a/pkg/transcrypt/crypto.go b/pkg/transcrypt/crypto.go index 7de4cbd..bde6b69 100644 --- a/pkg/transcrypt/crypto.go +++ b/pkg/transcrypt/crypto.go @@ -2,7 +2,11 @@ package transcrypt import ( "crypto/rand" + "crypto/rsa" "crypto/sha256" + "crypto/x509" + "encoding/hex" + "encoding/pem" "errors" "fmt" "io" @@ -12,6 +16,38 @@ import ( "golang.org/x/crypto/hkdf" ) +// CreateHexKey generates a random key which can be used for encryption. +// It generates a RSA Private Key with the supplied bitSize, and converts it to a hex-encoded PEM Block. +func CreateHexKey(bitSize int) (string, error) { + if bitSize < 12 { + return "", errors.New("bit size must be at least 12") + } + var err error + var privKey *rsa.PrivateKey + + var reader = rand.Reader + + if privKey, err = rsa.GenerateKey(reader, bitSize); err != nil { + return "", err + } + + return hex.EncodeToString(pem.EncodeToMemory( + &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privKey), + })), nil +} + +// CreateSalt creates a random 12-byte salt for use with the encrypt/decrypt functionality. +func CreateSalt() ([]byte, error) { + var nonce [12]byte + if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil { + return nil, fmt.Errorf("failed to read random data for nonce: %w", err) + } + + return nonce[:], nil +} + // createCryptoConfig creates a sio.config from the supplied key, cipher and optional salt. // It returns an error if either key or cipher is empty. // It also returns an error if the supplied salt is less than 12 bytes long. @@ -27,7 +63,7 @@ func createCryptoConfig(key string, cipher []byte, salt []byte) (sio.Config, err var err error // If salt is nil, create a new salt that can be used for encryption if salt == nil { - if salt, err = createSalt(); err != nil { + if salt, err = CreateSalt(); err != nil { return sio.Config{}, fmt.Errorf("could not create salt: %w", err) } } @@ -50,16 +86,6 @@ func createCryptoConfig(key string, cipher []byte, salt []byte) (sio.Config, err }, nil } -// createSalt creates a random salt for use with the encrypt/decrypt functionality. -func createSalt() ([]byte, error) { - var nonce [12]byte - if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil { - return nil, fmt.Errorf("failed to read random data for nonce: %w", err) - } - - return nonce[:], nil -} - // getKindFromString converts a string to its representative reflect.Kind. // It returns a reflect.Invalid by default if the supplied string cannot be found. func getKindForString(s string) reflect.Kind { diff --git a/pkg/transcrypt/crypto_test.go b/pkg/transcrypt/crypto_test.go index 87af8ea..d6fbb2c 100644 --- a/pkg/transcrypt/crypto_test.go +++ b/pkg/transcrypt/crypto_test.go @@ -5,6 +5,78 @@ import ( "testing" ) +func Test_CreateHexKey(t *testing.T) { + type args struct { + bitSize int + } + tests := []struct { + name string + bitSize int + wantErr bool + }{ + { + name: "invalid_size_0", + bitSize: 0, + wantErr: true, + }, + { + name: "invalid_size_11", + bitSize: 11, + wantErr: true, + }, + { + name: "valid_size_12", + bitSize: 12, + wantErr: false, + }, + { + name: "valid_size_256", + bitSize: 256, + wantErr: false, + }, + { + name: "valid_size_1024", + bitSize: 1024, + wantErr: false, + }, + { + name: "valid_size_2048", + bitSize: 2048, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := CreateHexKey(tt.bitSize) + if (err != nil) != tt.wantErr { + t.Errorf("CreateHexKey() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} + +func Test_CreateSalt(t *testing.T) { + tests := []struct { + name string + wantErr bool + }{ + { + name: "success", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := CreateSalt() + if (err != nil) != tt.wantErr { + t.Errorf("CreateSalt() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} + func Test_createCryptoConfig(t *testing.T) { type args struct { key string @@ -64,27 +136,6 @@ func Test_createCryptoConfig(t *testing.T) { } } -func Test_createSalt(t *testing.T) { - tests := []struct { - name string - wantErr bool - }{ - { - name: "succes", - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - _, err := createSalt() - if (err != nil) != tt.wantErr { - t.Errorf("createSalt() error = %v, wantErr %v", err, tt.wantErr) - return - } - }) - } -} - func Test_getKindForString(t *testing.T) { tests := []struct { name string