From fabee929b51c70e16eb0a040f9b2376229c53c80 Mon Sep 17 00:00:00 2001 From: Jan Tytgat Date: Mon, 13 Jan 2025 14:56:36 +0100 Subject: [PATCH] crypto.go: - Add documentation - Add tests --- pkg/transcrypt/crypto.go | 11 +- pkg/transcrypt/crypto_test.go | 237 ++++++++++++++++++++++++++++++++++ 2 files changed, 247 insertions(+), 1 deletion(-) create mode 100644 pkg/transcrypt/crypto_test.go diff --git a/pkg/transcrypt/crypto.go b/pkg/transcrypt/crypto.go index 7e50d92..7de4cbd 100644 --- a/pkg/transcrypt/crypto.go +++ b/pkg/transcrypt/crypto.go @@ -12,6 +12,9 @@ import ( "golang.org/x/crypto/hkdf" ) +// createCryptoConfig creates a sio.config from the supplied key, cipher and optional salt. +// It returns an error if either key or cipher is empty. +// It also returns an error if the supplied salt is less than 12 bytes long. func createCryptoConfig(key string, cipher []byte, salt []byte) (sio.Config, error) { if key == "" { return sio.Config{}, errors.New("key is empty") @@ -29,6 +32,10 @@ func createCryptoConfig(key string, cipher []byte, salt []byte) (sio.Config, err } } + if len(salt) < 12 { + return sio.Config{}, fmt.Errorf("salt needs to be at least 12 bytes, got %d", len(salt)) + } + // Create encryption key kdf := hkdf.New(sha256.New, []byte(key), salt[:12], nil) var encKey [32]byte @@ -43,7 +50,7 @@ func createCryptoConfig(key string, cipher []byte, salt []byte) (sio.Config, err }, nil } -// createSalt creates a random salt for use with the encrypt/decrypt functionality +// createSalt creates a random salt for use with the encrypt/decrypt functionality. func createSalt() ([]byte, error) { var nonce [12]byte if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil { @@ -53,6 +60,8 @@ func createSalt() ([]byte, error) { return nonce[:], nil } +// getKindFromString converts a string to its representative reflect.Kind. +// It returns a reflect.Invalid by default if the supplied string cannot be found. func getKindForString(s string) reflect.Kind { switch s { case "bool": diff --git a/pkg/transcrypt/crypto_test.go b/pkg/transcrypt/crypto_test.go new file mode 100644 index 0000000..87af8ea --- /dev/null +++ b/pkg/transcrypt/crypto_test.go @@ -0,0 +1,237 @@ +package transcrypt + +import ( + "reflect" + "testing" +) + +func Test_createCryptoConfig(t *testing.T) { + type args struct { + key string + cipher []byte + salt []byte + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "empty_key", + args: args{ + key: "test", + cipher: nil, + salt: nil, + }, + wantErr: true, + }, + { + name: "empty_cipher", + args: args{ + key: "test", + cipher: nil, + salt: nil, + }, + wantErr: true, + }, + { + name: "invalid_salt", + args: args{ + key: "test", + cipher: []byte("cipher"), + salt: []byte("salt"), + }, + wantErr: true, + }, + { + name: "valid", + args: args{ + key: "test", + cipher: []byte("cipher"), + salt: nil, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := createCryptoConfig(tt.args.key, tt.args.cipher, tt.args.salt) + if (err != nil) != tt.wantErr { + t.Errorf("createCryptoConfig() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} + +func Test_createSalt(t *testing.T) { + tests := []struct { + name string + wantErr bool + }{ + { + name: "succes", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := createSalt() + if (err != nil) != tt.wantErr { + t.Errorf("createSalt() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} + +func Test_getKindForString(t *testing.T) { + tests := []struct { + name string + kind string + want reflect.Kind + }{ + { + name: "bool", + kind: "bool", + want: reflect.Bool, + }, + { + name: "int", + kind: "int", + want: reflect.Int, + }, + { + name: "int8", + kind: "int8", + want: reflect.Int8, + }, + { + name: "int16", + kind: "int16", + want: reflect.Int16, + }, + { + name: "int32", + kind: "int32", + want: reflect.Int32, + }, + { + name: "int64", + kind: "int64", + want: reflect.Int64, + }, + { + name: "uint", + kind: "uint", + want: reflect.Uint, + }, + { + name: "uint8", + kind: "uint8", + want: reflect.Uint8, + }, + { + name: "uint16", + kind: "uint16", + want: reflect.Uint16, + }, + { + name: "uint32", + kind: "uint32", + want: reflect.Uint32, + }, + { + name: "uint64", + kind: "uint64", + want: reflect.Uint64, + }, + { + name: "uintptr", + kind: "uintptr", + want: reflect.Uintptr, + }, + { + name: "float32", + kind: "float32", + want: reflect.Float32, + }, + { + name: "float64", + kind: "float64", + want: reflect.Float64, + }, + { + name: "complex64", + kind: "complex64", + want: reflect.Complex64, + }, + { + name: "complex128", + kind: "complex128", + want: reflect.Complex128, + }, + { + name: "array", + kind: "array", + want: reflect.Array, + }, + { + name: "chan", + kind: "chan", + want: reflect.Chan, + }, + { + name: "func", + kind: "func", + want: reflect.Func, + }, + { + name: "interface", + kind: "interface", + want: reflect.Interface, + }, + { + name: "map", + kind: "map", + want: reflect.Map, + }, + { + name: "pointer", + kind: "pointer", + want: reflect.Pointer, + }, + { + name: "slice", + kind: "slice", + want: reflect.Slice, + }, + { + name: "string", + kind: "string", + want: reflect.String, + }, + { + name: "struct", + kind: "struct", + want: reflect.Struct, + }, + { + name: "unsafepointer", + kind: "unsafepointer", + want: reflect.UnsafePointer, + }, + { + name: "default", + kind: "default", + want: reflect.Invalid, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getKindForString(tt.kind); got != tt.want { + t.Errorf("getKindForString() = %v, want %v", got, tt.want) + } + }) + } +}