feat: support float and better test case
tests: use function to generate data, in future can use JSON test data record to do it matrix: use f32 type
This commit is contained in:
parent
dd32b816d3
commit
3c2dbe0d01
4 changed files with 124 additions and 41 deletions
|
|
@ -28,7 +28,7 @@ pub use matrix::{Matrix, MatrixMath};
|
||||||
|
|
||||||
pub fn test() {
|
pub fn test() {
|
||||||
println!("Testing code here");
|
println!("Testing code here");
|
||||||
let m = Matrix::from(vec![1,2,3,4,5]);
|
let m = Matrix::from(vec![1.0,2.0,3.0,4.0,5.0]);
|
||||||
m.transpose();
|
m.transpose();
|
||||||
m.determinant();
|
m.determinant();
|
||||||
}
|
}
|
||||||
|
|
|
||||||
105
src/matrix.rs
105
src/matrix.rs
|
|
@ -10,14 +10,14 @@
|
||||||
//! println!("m1 + m2 =\n{}", m_add);
|
//! println!("m1 + m2 =\n{}", m_add);
|
||||||
//! ```
|
//! ```
|
||||||
//! TODO:: Create matrix multiplication method
|
//! TODO:: Create matrix multiplication method
|
||||||
|
use core::ops::AddAssign;
|
||||||
use crate::error::{MatrixSetValueError, ParseMatrixError};
|
use crate::error::{MatrixSetValueError, ParseMatrixError};
|
||||||
use std::{
|
use std::{
|
||||||
fmt::Display,
|
fmt::Display,
|
||||||
ops::{Add, Mul, Sub},
|
ops::{Add, Mul, Sub},
|
||||||
str::FromStr,
|
str::FromStr,
|
||||||
};
|
};
|
||||||
#[derive(Debug, PartialEq, Eq)]
|
#[derive(Debug, PartialEq)]
|
||||||
pub struct Matrix {
|
pub struct Matrix {
|
||||||
/// Number of rows in matrix.
|
/// Number of rows in matrix.
|
||||||
pub nrows: usize,
|
pub nrows: usize,
|
||||||
|
|
@ -26,12 +26,16 @@ pub struct Matrix {
|
||||||
pub ncols: usize,
|
pub ncols: usize,
|
||||||
|
|
||||||
/// Data stored in the matrix, you should not access this directly
|
/// Data stored in the matrix, you should not access this directly
|
||||||
data: Vec<Vec<i32>>,
|
data: Vec<Vec<f32>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait MatrixMath {
|
pub trait MatrixMath {
|
||||||
fn inverse(&self) -> Matrix {
|
fn inverse(&self) -> Option<Matrix> {
|
||||||
(1 / (self.determinant())) * &self.adjoint()
|
let det_m = self.determinant();
|
||||||
|
if det_m == 0.0 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
Some((1.0 / det_m) * &self.adjoint())
|
||||||
}
|
}
|
||||||
/// Finds the matrix of cofactors for any N-by-N matrix
|
/// Finds the matrix of cofactors for any N-by-N matrix
|
||||||
fn cofactor(&self) -> Matrix {
|
fn cofactor(&self) -> Matrix {
|
||||||
|
|
@ -42,7 +46,7 @@ pub trait MatrixMath {
|
||||||
todo!();
|
todo!();
|
||||||
}
|
}
|
||||||
/// Finds the determinant of any N-by-N matrix.
|
/// Finds the determinant of any N-by-N matrix.
|
||||||
fn determinant(&self) -> i32 {
|
fn determinant(&self) -> f32 {
|
||||||
todo!();
|
todo!();
|
||||||
}
|
}
|
||||||
/// Finds the transpose of any matrix.
|
/// Finds the transpose of any matrix.
|
||||||
|
|
@ -55,32 +59,58 @@ pub trait MatrixMath {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
impl MatrixMath for Matrix {
|
impl MatrixMath for Matrix {
|
||||||
|
fn cofactor(&self) -> Matrix {
|
||||||
|
let mut d: Vec<Vec<f32>> = Vec::new();
|
||||||
|
for (i, r) in self.data.iter().enumerate() {
|
||||||
|
let mut nr: Vec<f32> = Vec::new();
|
||||||
|
for (j, v) in r.iter().enumerate() {
|
||||||
|
let count = self.ncols * i + j;
|
||||||
|
let nv = if count % 2 == 0 { -*v } else { *v };
|
||||||
|
nr.push(nv);
|
||||||
|
}
|
||||||
|
d.push(nr);
|
||||||
|
}
|
||||||
|
Matrix::new(d)
|
||||||
|
}
|
||||||
|
fn minor(&self) -> Matrix {
|
||||||
|
let mut d: Vec<Vec<f32>> = Vec::new();
|
||||||
|
for (i, r) in self.data.iter().enumerate() {
|
||||||
|
let mut nr: Vec<f32> = Vec::new();
|
||||||
|
for (j, v) in r.iter().enumerate() {
|
||||||
|
let count = self.ncols * i + j;
|
||||||
|
let nv = if count % 2 == 0 { -*v } else { *v };
|
||||||
|
nr.push(nv * self.splice(j, i).determinant());
|
||||||
|
}
|
||||||
|
d.push(nr);
|
||||||
|
}
|
||||||
|
Matrix::new(d)
|
||||||
|
}
|
||||||
/// Evaluates any N-by-N matrix.
|
/// Evaluates any N-by-N matrix.
|
||||||
///
|
///
|
||||||
/// This function panics if the matrix is not square!
|
/// This function panics if the matrix is not square!
|
||||||
fn determinant(&self) -> i32 {
|
fn determinant(&self) -> f32 {
|
||||||
if !self.is_square() {
|
if !self.is_square() {
|
||||||
panic!()
|
panic!()
|
||||||
};
|
};
|
||||||
if self.nrows == 2 && self.ncols == 2 {
|
if self.nrows == 2 && self.ncols == 2 {
|
||||||
return self.data[0][0] * self.data[1][1] - self.data[0][1] * self.data[1][0];
|
return self.data[0][0] * self.data[1][1] - self.data[0][1] * self.data[1][0];
|
||||||
}
|
}
|
||||||
let mut tmp = 0;
|
let mut tmp: f32 = 0.0;
|
||||||
for (i, n) in self.data[0].iter().enumerate() {
|
for (i, n) in self.data[0].iter().enumerate() {
|
||||||
let mult = if i % 2 == 0 { -*n } else { *n };
|
let mult = if i % 2 == 0 { -*n } else { *n };
|
||||||
let eval = self.splice(i).determinant();
|
let eval = self.splice(i, 0).determinant();
|
||||||
tmp += mult * eval;
|
tmp = tmp + mult * f32::from(eval);
|
||||||
}
|
}
|
||||||
tmp
|
tmp.into()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Evaluates the tranpose of the matrix.
|
/// Evaluates the tranpose of the matrix.
|
||||||
///
|
///
|
||||||
/// Each row becomes a column, each column becomes a row.
|
/// Each row becomes a column, each column becomes a row.
|
||||||
fn transpose(&self) -> Matrix {
|
fn transpose(&self) -> Matrix {
|
||||||
let mut new_data = Vec::<Vec<i32>>::new();
|
let mut new_data = Vec::<Vec<f32>>::new();
|
||||||
for i in 0..self.nrows {
|
for i in 0..self.nrows {
|
||||||
let mut new_row = Vec::<i32>::new();
|
let mut new_row = Vec::<f32>::new();
|
||||||
for j in 0..self.ncols {
|
for j in 0..self.ncols {
|
||||||
new_row.push(self.data[j][i]);
|
new_row.push(self.data[j][i]);
|
||||||
}
|
}
|
||||||
|
|
@ -100,17 +130,26 @@ impl Matrix {
|
||||||
///
|
///
|
||||||
/// TODOs
|
/// TODOs
|
||||||
/// - Add row length check
|
/// - Add row length check
|
||||||
pub fn new(data: Vec<Vec<i32>>) -> Matrix {
|
pub fn new(data: Vec<Vec<f32>>) -> Matrix {
|
||||||
|
let mut d: Vec<Vec<f32>> = Vec::new();
|
||||||
|
|
||||||
|
for r in &data {
|
||||||
|
let mut nr = vec![];
|
||||||
|
for x in r {
|
||||||
|
nr.push(*x);
|
||||||
|
}
|
||||||
|
d.push(nr);
|
||||||
|
}
|
||||||
Matrix {
|
Matrix {
|
||||||
nrows: data.len(),
|
nrows: data.len(),
|
||||||
ncols: data[0].len(),
|
ncols: data[0].len(),
|
||||||
data,
|
data: d,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/// Query one element at selected position.
|
/// Query one element at selected position.
|
||||||
///
|
///
|
||||||
/// Returns `None` if index is out of bounds.
|
/// Returns `None` if index is out of bounds.
|
||||||
pub fn get(&self, row_index: usize, column_index: usize) -> Option<i32> {
|
pub fn get(&self, row_index: usize, column_index: usize) -> Option<f32> {
|
||||||
let r = self.data.get(row_index)?;
|
let r = self.data.get(row_index)?;
|
||||||
let n = r.get(column_index)?;
|
let n = r.get(column_index)?;
|
||||||
Some(*n)
|
Some(*n)
|
||||||
|
|
@ -123,7 +162,7 @@ impl Matrix {
|
||||||
&mut self,
|
&mut self,
|
||||||
row_index: usize,
|
row_index: usize,
|
||||||
column_index: usize,
|
column_index: usize,
|
||||||
new_data: i32,
|
new_data: f32,
|
||||||
) -> Result<(), MatrixSetValueError> {
|
) -> Result<(), MatrixSetValueError> {
|
||||||
self.data[row_index][column_index] = new_data;
|
self.data[row_index][column_index] = new_data;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
@ -133,13 +172,13 @@ impl Matrix {
|
||||||
pub fn is_square(&self) -> bool {
|
pub fn is_square(&self) -> bool {
|
||||||
self.nrows == self.ncols
|
self.nrows == self.ncols
|
||||||
}
|
}
|
||||||
pub fn splice(&self, at_index: usize) -> Matrix {
|
fn splice(&self, at_index: usize, at_row: usize) -> Matrix {
|
||||||
let mut data: Vec<Vec<i32>> = Vec::new();
|
let mut data: Vec<Vec<f32>> = Vec::new();
|
||||||
for i in 0..self.data.len() {
|
for i in 0..self.data.len() {
|
||||||
if i == 0 {
|
if i == at_row {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
let mut r: Vec<i32> = Vec::new();
|
let mut r: Vec<f32> = Vec::new();
|
||||||
for j in 0..self.data[i].len() {
|
for j in 0..self.data[i].len() {
|
||||||
if j == at_index {
|
if j == at_index {
|
||||||
continue;
|
continue;
|
||||||
|
|
@ -154,12 +193,12 @@ impl Matrix {
|
||||||
impl FromStr for Matrix {
|
impl FromStr for Matrix {
|
||||||
type Err = ParseMatrixError;
|
type Err = ParseMatrixError;
|
||||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||||
let mut d: Vec<Vec<i32>> = Vec::new();
|
let mut d: Vec<Vec<f32>> = Vec::new();
|
||||||
let rows_iter = s.split('\n');
|
let rows_iter = s.split('\n');
|
||||||
for txt in rows_iter {
|
for txt in rows_iter {
|
||||||
let mut r: Vec<i32> = Vec::new();
|
let mut r: Vec<f32> = Vec::new();
|
||||||
for ch in txt.split(',') {
|
for ch in txt.split(',') {
|
||||||
let parsed = match i32::from_str(ch) {
|
let parsed = match f32::from_str(ch) {
|
||||||
Ok(n) => Ok(n),
|
Ok(n) => Ok(n),
|
||||||
Err(_e) => Err(ParseMatrixError),
|
Err(_e) => Err(ParseMatrixError),
|
||||||
};
|
};
|
||||||
|
|
@ -218,12 +257,12 @@ impl<'a, 'b> Sub<&'b Matrix> for &'a Matrix {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
impl<'a> Mul<&'a Matrix> for i32 {
|
impl<'a> Mul<&'a Matrix> for f32 {
|
||||||
type Output = Matrix;
|
type Output = Matrix;
|
||||||
fn mul(self, rhs: &'a Matrix) -> Self::Output {
|
fn mul(self, rhs: &'a Matrix) -> Self::Output {
|
||||||
let mut d: Vec<Vec<i32>> = Vec::new();
|
let mut d: Vec<Vec<f32>> = Vec::new();
|
||||||
for r in &rhs.data {
|
for r in &rhs.data {
|
||||||
let mut nr: Vec<i32> = Vec::new();
|
let mut nr: Vec<f32> = Vec::new();
|
||||||
for v in r {
|
for v in r {
|
||||||
nr.push(self * v);
|
nr.push(self * v);
|
||||||
}
|
}
|
||||||
|
|
@ -235,21 +274,21 @@ impl<'a> Mul<&'a Matrix> for i32 {
|
||||||
impl<'a, 'b> Mul<&'b Matrix> for &'a Matrix {
|
impl<'a, 'b> Mul<&'b Matrix> for &'a Matrix {
|
||||||
type Output = Matrix;
|
type Output = Matrix;
|
||||||
fn mul(self, rhs: &'b Matrix) -> Self::Output {
|
fn mul(self, rhs: &'b Matrix) -> Self::Output {
|
||||||
fn reduce(lhs: &Matrix, rhs: &Matrix, at_r: usize, at_c: usize) -> i32 {
|
fn reduce(lhs: &Matrix, rhs: &Matrix, at_r: usize, at_c: usize) -> f32 {
|
||||||
let mut tmp = 0;
|
let mut tmp = 0.0;
|
||||||
for i in 0..lhs.ncols {
|
for i in 0..lhs.ncols {
|
||||||
tmp += lhs.get(at_r, i).unwrap() * rhs.get(i, at_c).unwrap();
|
tmp += lhs.get(at_r, i).unwrap() * rhs.get(i, at_c).unwrap();
|
||||||
}
|
}
|
||||||
tmp
|
tmp
|
||||||
}
|
}
|
||||||
let mut d: Vec<Vec<i32>> = Vec::new();
|
let mut d: Vec<Vec<f32>> = Vec::new();
|
||||||
if self.ncols != rhs.nrows {
|
if self.ncols != rhs.nrows {
|
||||||
println!("LHS: \n{}RHS: \n{}", self, rhs);
|
println!("LHS: \n{}RHS: \n{}", self, rhs);
|
||||||
println!("LHS nrows: {} ;; RHS ncols: {}", self.nrows, rhs.ncols);
|
println!("LHS nrows: {} ;; RHS ncols: {}", self.nrows, rhs.ncols);
|
||||||
panic!()
|
panic!()
|
||||||
}
|
}
|
||||||
for i in 0..self.nrows {
|
for i in 0..self.nrows {
|
||||||
let mut r: Vec<i32> = Vec::new();
|
let mut r: Vec<f32> = Vec::new();
|
||||||
for j in 0..rhs.ncols {
|
for j in 0..rhs.ncols {
|
||||||
r.push(reduce(self, rhs, i, j));
|
r.push(reduce(self, rhs, i, j));
|
||||||
}
|
}
|
||||||
|
|
@ -259,8 +298,8 @@ impl<'a, 'b> Mul<&'b Matrix> for &'a Matrix {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<Vec<i32>> for Matrix {
|
impl From<Vec<f32>> for Matrix {
|
||||||
fn from(value: Vec<i32>) -> Self {
|
fn from(value: Vec<f32>) -> Self {
|
||||||
Matrix {
|
Matrix {
|
||||||
nrows: value.len(),
|
nrows: value.len(),
|
||||||
ncols: 1,
|
ncols: 1,
|
||||||
|
|
|
||||||
|
|
@ -2,23 +2,67 @@ use std::str::FromStr;
|
||||||
|
|
||||||
use crate::{error::ParseMatrixError, matrix::Matrix, MatrixMath};
|
use crate::{error::ParseMatrixError, matrix::Matrix, MatrixMath};
|
||||||
|
|
||||||
|
enum TestCaseType {
|
||||||
|
Add,
|
||||||
|
Mul,
|
||||||
|
Inv,
|
||||||
|
CmpErr,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct TestCase {
|
||||||
|
test_type: TestCaseType,
|
||||||
|
test_data: Vec<Matrix>,
|
||||||
|
}
|
||||||
|
fn build_add_test_cases() -> Vec<TestCase> {
|
||||||
|
let mut v = vec![];
|
||||||
|
let from_strs = vec![
|
||||||
|
"1,2,3\n4,5,6\n7,8,9",
|
||||||
|
"1,1,1\n1,1,1\n1,1,1",
|
||||||
|
"2,3,4\n5,6,7\n8,9,10",
|
||||||
|
|
||||||
|
|
||||||
|
"1,1,1\n1,1,1\n1,1,1",
|
||||||
|
"0,0,0\n0,0,0\n0,0,0",
|
||||||
|
"1,1,1\n1,1,1\n1,1,1",
|
||||||
|
];
|
||||||
|
let mut i = 0;
|
||||||
|
while i < from_strs.len() {
|
||||||
|
let m1 = Matrix::from_str(from_strs[i]).unwrap();
|
||||||
|
let m2 = Matrix::from_str(from_strs[i+1]).unwrap();
|
||||||
|
let mr = Matrix::from_str(from_strs[i+2]).unwrap();
|
||||||
|
v.push(TestCase {
|
||||||
|
test_type: TestCaseType::Add,
|
||||||
|
test_data: vec![m1, m2, mr],
|
||||||
|
});
|
||||||
|
i += 3;
|
||||||
|
}
|
||||||
|
v
|
||||||
|
}
|
||||||
#[test]
|
#[test]
|
||||||
pub fn test_matrix_add() -> Result<(), ParseMatrixError> {
|
pub fn test_matrix_add() -> Result<(), ParseMatrixError> {
|
||||||
let m1 = Matrix::from_str("1,2,3\n4,5,6\n7,8,9")?;
|
let cases = build_add_test_cases();
|
||||||
let m2 = Matrix::from_str("1,1,1\n1,1,1\n1,1,1")?;
|
for case in cases {
|
||||||
let t = Matrix::from_str("2,3,4\n5,6,7\n8,9,10")?;
|
assert_eq!(&case.test_data[0] + &case.test_data[1], case.test_data[2]);
|
||||||
assert_eq!(&m1 + &m2, t);
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
#[test]
|
#[test]
|
||||||
pub fn test_matrix_determinate() -> Result<(), ParseMatrixError> {
|
pub fn test_matrix_determinate() -> Result<(), ParseMatrixError> {
|
||||||
let m = Matrix::from_str("3,4\n5,6")?;
|
let m = Matrix::from_str("3,4\n5,6")?;
|
||||||
let det = 3 * 6 - 4 * 5;
|
let det = 3.0 * 6.0 - 4.0 * 5.0;
|
||||||
assert_eq!(m.determinant(), det);
|
assert_eq!(m.determinant(), det);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
#[test]
|
#[test]
|
||||||
pub fn test_matrix_transposition() -> Result<(), ParseMatrixError> {
|
pub fn test_matrix_inverse_on_singular() -> Result<(), ()> {
|
||||||
|
let m = Matrix::new(vec![vec![1.0,2.0,3.0], vec![4.0,5.0,6.0], vec![7.0,8.0,9.0]]);
|
||||||
|
match m.inverse() {
|
||||||
|
Some(_inverse) => Err(()),
|
||||||
|
None => Ok(()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#[test]
|
||||||
|
pub fn test_matrix_transpose() -> Result<(), ParseMatrixError> {
|
||||||
let m = Matrix::from_str("1,2,3\n4,5,6\n7,8,9")?;
|
let m = Matrix::from_str("1,2,3\n4,5,6\n7,8,9")?;
|
||||||
let t = Matrix::from_str("1,4,7\n2,5,8\n3,6,9")?;
|
let t = Matrix::from_str("1,4,7\n2,5,8\n3,6,9")?;
|
||||||
assert_eq!(m.transpose(), t);
|
assert_eq!(m.transpose(), t);
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ use crate::{matrix::Matrix, error::ParseMatrixError};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
pub fn test_matrix_init_from_string() -> Result<(), ParseMatrixError> {
|
pub fn test_matrix_init_from_string() -> Result<(), ParseMatrixError> {
|
||||||
let data_target = vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]];
|
let data_target = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0], vec![7.0, 8.0, 9.0]];
|
||||||
let target = Matrix::new(data_target);
|
let target = Matrix::new(data_target);
|
||||||
let test = Matrix::from_str("1,2,3\n4,5,6\n7,8,9")?;
|
let test = Matrix::from_str("1,2,3\n4,5,6\n7,8,9")?;
|
||||||
assert_eq!(target, test);
|
assert_eq!(target, test);
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue