diff --git a/audio/wav/decode.go b/audio/wav/decode.go
index 33e0fb1d1c9e83dd572eea4aa3e0d5a176ad4cea..116b4264b2702a4d724c8c3c0076c7a790d0b919 100644
--- a/audio/wav/decode.go
+++ b/audio/wav/decode.go
@@ -116,47 +116,37 @@ func (s *decoder) Stream(samples [][2]float64) (n int, ok bool) {
 	if s.err != nil || s.pos >= s.h.DataSize {
 		return 0, false
 	}
-	var frameWidth int
-	switch {
-	case s.h.BitsPerSample == 8 && s.h.NumChans == 1:
-		frameWidth = 1
-	case s.h.BitsPerSample == 8 && s.h.NumChans >= 2:
-		frameWidth = int(s.h.NumChans)
-	case s.h.BitsPerSample == 16 && s.h.NumChans == 1:
-		frameWidth = 2
-	case s.h.BitsPerSample == 16 && s.h.NumChans >= 2:
-		frameWidth = int(s.h.NumChans) * 2
-	}
-	p := make([]byte, len(samples)*frameWidth)
+	bytesPerFrame := int(s.h.BytesPerFrame)
+	p := make([]byte, len(samples)*bytesPerFrame)
 	n, err := s.rsc.Read(p)
 	if err != nil {
 		s.err = err
 	}
 	switch {
 	case s.h.BitsPerSample == 8 && s.h.NumChans == 1:
-		for i, j := 0, 0; i < n-frameWidth; i, j = i+frameWidth, j+1 {
+		for i, j := 0, 0; i < n-bytesPerFrame; i, j = i+bytesPerFrame, j+1 {
 			val := float64(p[i])/(1<<8-1)*2 - 1
 			samples[j][0] = val
 			samples[j][1] = val
 		}
 	case s.h.BitsPerSample == 8 && s.h.NumChans >= 2:
-		for i, j := 0, 0; i < n-frameWidth; i, j = i+frameWidth, j+1 {
+		for i, j := 0, 0; i < n-bytesPerFrame; i, j = i+bytesPerFrame, j+1 {
 			samples[j][0] = float64(p[i+0])/(1<<8-1)*2 - 1
 			samples[j][1] = float64(p[i+1])/(1<<8-1)*2 - 1
 		}
 	case s.h.BitsPerSample == 16 && s.h.NumChans == 1:
-		for i, j := 0, 0; i < n-frameWidth; i, j = i+frameWidth, j+1 {
+		for i, j := 0, 0; i < n-bytesPerFrame; i, j = i+bytesPerFrame, j+1 {
 			val := float64(int16(p[i+0])+int16(p[i+1])*(1<<8)) / (1<<15 - 1)
 			samples[j][0] = val
 			samples[j][1] = val
 		}
 	case s.h.BitsPerSample == 16 && s.h.NumChans >= 2:
-		for i, j := 0, 0; i <= n-frameWidth; i, j = i+frameWidth, j+1 {
+		for i, j := 0, 0; i <= n-bytesPerFrame; i, j = i+bytesPerFrame, j+1 {
 			samples[j][0] = float64(int16(p[i+0])+int16(p[i+1])*(1<<8)) / (1<<15 - 1)
 			samples[j][1] = float64(int16(p[i+2])+int16(p[i+3])*(1<<8)) / (1<<15 - 1)
 		}
 	}
-	return n / frameWidth, true
+	return n / bytesPerFrame, true
 }
 
 func (s *decoder) Close() error {