feat(matrix_op): add multiplication

matrix_ops: add matrix multiply capability
chore(fmt): flatten crate directory
tests(ops): add tests for multiplication
This commit is contained in:
Zhongheng Liu 2025-01-23 19:06:49 +02:00
commit 333c6c7281
Signed by: steven
GPG key ID: 805A28B071DAD84B
7 changed files with 70 additions and 40 deletions

View file

@ -9,20 +9,23 @@
//!
//! Examples:
//! ```
//! ...
//! use matrix::Matrix;
//! let m = Matrix::from_str("1,2,3\n4,5,6\n7,8,9");
//! use std::str::FromStr;
//! let m = Matrix::from_str("1,2,3\n4,5,6\n7,8,9").expect("Expected this to work");
//! println!("Matrix string formatting:\n{}", m);
//! println!("Evaluate determinant of matrix: {}", m.determinant());
//! println!("Transpose of matrix m:\n{}", m.transpose());
//! ...
//!
//! ```
pub mod types;
mod matrix;
pub mod error;
#[cfg(test)]
mod tests;
pub use matrix::Matrix;
pub fn test() {
println!("Testing code here");
}

View file

@ -3,16 +3,17 @@
//! Example usage - addition of two matrices:
//! ```
//! use matrix::Matrix;
//! let m1 = Matrix::from_str("1,1,1\n1,1,1\n1,1,1");
//! let m2 = Matrix::from_str("2,2,2\n2,2,2\n2,2,2");
//! println!("Sum of m1 + m2: \n{}", m1 + m2);
//! use std::str::FromStr;
//! let m1 = Matrix::from_str("1,2\n3,4").expect("Expect parse correct");
//! let m2 = Matrix::from_str("1,1\n1,1").expect("Expect parse correct");
//! let m_add = &m1 + &m2;
//! println!("m1 + m2 =\n{}", m_add);
//! ```
//!
//! TODO:: Create matrix multiplication method
use std::{fmt::Display, ops::{Add, Mul, Sub}, str::FromStr};
use super::matrix_err::{MatrixSetValueError, ParseMatrixError};
use crate::error::{MatrixSetValueError, ParseMatrixError};
#[derive(Debug, PartialEq, Eq)]
pub struct Matrix {
/// Number of rows in matrix.
@ -95,7 +96,27 @@ impl<'a, 'b> Sub<&'b Matrix> for &'a Matrix {
impl<'a, 'b> Mul<&'b Matrix> for &'a Matrix {
type Output = Matrix;
fn mul(self, rhs: &'b Matrix) -> Self::Output {
todo!()
fn reduce(lhs: &Matrix, rhs: &Matrix, at_r: usize, at_c: usize) -> i32 {
let mut tmp = 0;
for i in 0..lhs.ncols {
tmp += lhs.get(at_r, i).unwrap() * rhs.get(i, at_c).unwrap();
}
tmp
}
let mut d: Vec<Vec<i32>> = Vec::new();
if self.ncols != rhs.nrows {
println!("LHS: \n{}RHS: \n{}", self, rhs);
println!("LHS nrows: {} ;; RHS ncols: {}", self.nrows, rhs.ncols);
panic!()
}
for i in 0..self.nrows {
let mut r: Vec<i32> = Vec::new();
for j in 0..rhs.ncols {
r.push(reduce(self, rhs, i, j));
}
d.push(r);
}
Matrix::new(d)
}
}

View file

@ -1,2 +1,3 @@
#[cfg(test)]
pub mod matrix_test;
mod matrix_test_parse;
mod matrix_test_ops;

View file

@ -1,19 +1,7 @@
use std::str::FromStr;
use crate::types::{matrix::Matrix, matrix_err::ParseMatrixError};
use crate::{matrix::Matrix, error::ParseMatrixError};
#[test]
pub fn test_matrix_init_from_string() -> Result<(), ParseMatrixError> {
let data_target = vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]];
let target = Matrix {
nrows: 3,
ncols: 3,
data: data_target,
};
let test = Matrix::from_str("1,2,3\n4,5,6\n7,8,9")?;
assert_eq!(target, test);
Ok(())
}
#[test]
pub fn test_matrix_add() -> Result<(), ParseMatrixError> {
let m1 = Matrix::from_str("1,2,3\n4,5,6\n7,8,9")?;
@ -37,13 +25,16 @@ pub fn test_matrix_transposition() -> Result<(), ParseMatrixError> {
Ok(())
}
#[test]
pub fn test_matrix_parse_malformed() -> () {
let malformed = "1,23,\n,567,\n\n5";
let m = Matrix::from_str(malformed);
match m {
Ok(_) => panic!("This malformed matrix string should not have succeeded"),
Err(_) => (),
}
pub fn test_matrix_mul() -> Result<(), ParseMatrixError> {
let m1 = Matrix::from_str("1,2\n1,2")?;
let m2 = Matrix::from_str("1,3\n2,4")?;
let m3 = Matrix::from_str("1,2\n3,4")?;
let m4 = Matrix::from_str("1\n2")?;
let t1 = Matrix::from_str("5,11\n5,11")?;
let t2 = Matrix::from_str("5\n11")?;
assert_eq!(&m1 * &m2, t1);
assert_eq!(&m3 * &m4, t2);
Ok(())
}
#[test]
#[should_panic]

View file

@ -0,0 +1,22 @@
use std::str::FromStr;
use crate::{matrix::Matrix, error::ParseMatrixError};
#[test]
pub fn test_matrix_init_from_string() -> Result<(), ParseMatrixError> {
let data_target = vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]];
let target = Matrix::new(data_target);
let test = Matrix::from_str("1,2,3\n4,5,6\n7,8,9")?;
assert_eq!(target, test);
Ok(())
}
#[test]
pub fn test_matrix_parse_malformed() -> () {
let malformed = "1,23,\n,567,\n\n5";
let m = Matrix::from_str(malformed);
match m {
Ok(_) => panic!("This malformed matrix string should not have succeeded"),
Err(_) => (),
}
}

View file

@ -1,8 +0,0 @@
//! Matrix-related type definitions
//!
//! Includes modules:
//! - Matrix
//! - Matrix parse and arithmetic errors
pub mod matrix;
pub mod matrix_err;