crypto.go:

- Add documentation
- Add tests
This commit is contained in:
Jan Tytgat
2025-01-13 14:56:36 +01:00
parent 92feadbb70
commit fabee929b5
2 changed files with 247 additions and 1 deletions

View File

@ -12,6 +12,9 @@ import (
"golang.org/x/crypto/hkdf" "golang.org/x/crypto/hkdf"
) )
// createCryptoConfig creates a sio.config from the supplied key, cipher and optional salt.
// It returns an error if either key or cipher is empty.
// It also returns an error if the supplied salt is less than 12 bytes long.
func createCryptoConfig(key string, cipher []byte, salt []byte) (sio.Config, error) { func createCryptoConfig(key string, cipher []byte, salt []byte) (sio.Config, error) {
if key == "" { if key == "" {
return sio.Config{}, errors.New("key is empty") return sio.Config{}, errors.New("key is empty")
@ -29,6 +32,10 @@ func createCryptoConfig(key string, cipher []byte, salt []byte) (sio.Config, err
} }
} }
if len(salt) < 12 {
return sio.Config{}, fmt.Errorf("salt needs to be at least 12 bytes, got %d", len(salt))
}
// Create encryption key // Create encryption key
kdf := hkdf.New(sha256.New, []byte(key), salt[:12], nil) kdf := hkdf.New(sha256.New, []byte(key), salt[:12], nil)
var encKey [32]byte var encKey [32]byte
@ -43,7 +50,7 @@ func createCryptoConfig(key string, cipher []byte, salt []byte) (sio.Config, err
}, nil }, nil
} }
// createSalt creates a random salt for use with the encrypt/decrypt functionality // createSalt creates a random salt for use with the encrypt/decrypt functionality.
func createSalt() ([]byte, error) { func createSalt() ([]byte, error) {
var nonce [12]byte var nonce [12]byte
if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil { if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
@ -53,6 +60,8 @@ func createSalt() ([]byte, error) {
return nonce[:], nil return nonce[:], nil
} }
// getKindFromString converts a string to its representative reflect.Kind.
// It returns a reflect.Invalid by default if the supplied string cannot be found.
func getKindForString(s string) reflect.Kind { func getKindForString(s string) reflect.Kind {
switch s { switch s {
case "bool": case "bool":

View File

@ -0,0 +1,237 @@
package transcrypt
import (
"reflect"
"testing"
)
func Test_createCryptoConfig(t *testing.T) {
type args struct {
key string
cipher []byte
salt []byte
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "empty_key",
args: args{
key: "test",
cipher: nil,
salt: nil,
},
wantErr: true,
},
{
name: "empty_cipher",
args: args{
key: "test",
cipher: nil,
salt: nil,
},
wantErr: true,
},
{
name: "invalid_salt",
args: args{
key: "test",
cipher: []byte("cipher"),
salt: []byte("salt"),
},
wantErr: true,
},
{
name: "valid",
args: args{
key: "test",
cipher: []byte("cipher"),
salt: nil,
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := createCryptoConfig(tt.args.key, tt.args.cipher, tt.args.salt)
if (err != nil) != tt.wantErr {
t.Errorf("createCryptoConfig() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
func Test_createSalt(t *testing.T) {
tests := []struct {
name string
wantErr bool
}{
{
name: "succes",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := createSalt()
if (err != nil) != tt.wantErr {
t.Errorf("createSalt() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
func Test_getKindForString(t *testing.T) {
tests := []struct {
name string
kind string
want reflect.Kind
}{
{
name: "bool",
kind: "bool",
want: reflect.Bool,
},
{
name: "int",
kind: "int",
want: reflect.Int,
},
{
name: "int8",
kind: "int8",
want: reflect.Int8,
},
{
name: "int16",
kind: "int16",
want: reflect.Int16,
},
{
name: "int32",
kind: "int32",
want: reflect.Int32,
},
{
name: "int64",
kind: "int64",
want: reflect.Int64,
},
{
name: "uint",
kind: "uint",
want: reflect.Uint,
},
{
name: "uint8",
kind: "uint8",
want: reflect.Uint8,
},
{
name: "uint16",
kind: "uint16",
want: reflect.Uint16,
},
{
name: "uint32",
kind: "uint32",
want: reflect.Uint32,
},
{
name: "uint64",
kind: "uint64",
want: reflect.Uint64,
},
{
name: "uintptr",
kind: "uintptr",
want: reflect.Uintptr,
},
{
name: "float32",
kind: "float32",
want: reflect.Float32,
},
{
name: "float64",
kind: "float64",
want: reflect.Float64,
},
{
name: "complex64",
kind: "complex64",
want: reflect.Complex64,
},
{
name: "complex128",
kind: "complex128",
want: reflect.Complex128,
},
{
name: "array",
kind: "array",
want: reflect.Array,
},
{
name: "chan",
kind: "chan",
want: reflect.Chan,
},
{
name: "func",
kind: "func",
want: reflect.Func,
},
{
name: "interface",
kind: "interface",
want: reflect.Interface,
},
{
name: "map",
kind: "map",
want: reflect.Map,
},
{
name: "pointer",
kind: "pointer",
want: reflect.Pointer,
},
{
name: "slice",
kind: "slice",
want: reflect.Slice,
},
{
name: "string",
kind: "string",
want: reflect.String,
},
{
name: "struct",
kind: "struct",
want: reflect.Struct,
},
{
name: "unsafepointer",
kind: "unsafepointer",
want: reflect.UnsafePointer,
},
{
name: "default",
kind: "default",
want: reflect.Invalid,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := getKindForString(tt.kind); got != tt.want {
t.Errorf("getKindForString() = %v, want %v", got, tt.want)
}
})
}
}