- Sort Score
- Result 10 results
- Languages All
Results 1 - 10 of 27 for broadcastable (0.21 sec)
-
tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h
// Scalar identity is broadcastable to any operand shape, we only need to // check that operand has the same shape as a result. bool scalar_identity = identity_ty.hasRank() && identity_ty.getRank() == 0; if (scalar_identity) return operand_ty == result_ty; // If identity is not a scalar, we must verify that identity shape is // statically known to be broadcastable to the operand shape and the operand
Registered: Sun Jun 16 05:45:23 UTC 2024 - Last Modified: Thu Apr 25 16:01:03 UTC 2024 - 5.3K bytes - Viewed (0) -
tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc
return failure(); } } // Ensure that batch shapes are broadcastable. tensorflow::MatMulBCast bcast( absl::InlinedVector<int64_t, 4>(lhs_shape.begin(), lhs_shape.end()), absl::InlinedVector<int64_t, 4>(rhs_shape.begin(), rhs_shape.end())); if (!bcast.IsValid()) { // Input batch dimensions must be broadcastable return failure(); }
Registered: Sun Jun 16 05:45:23 UTC 2024 - Last Modified: Thu Apr 25 16:01:03 UTC 2024 - 11.6K bytes - Viewed (0) -
tensorflow/compiler/mlir/lite/utils/arithmetic_count_util.h
#include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { namespace TFL { // For add/mul/div/sub and other broadcastable ops. class ArithmeticCountUtilHelper { public: static bool GetFirstOutputCount(mlir::Operation* op, int64_t* count) { auto output = op->getResult(0); auto output_type =
Registered: Sun Jun 16 05:45:23 UTC 2024 - Last Modified: Thu Apr 25 16:01:03 UTC 2024 - 3.1K bytes - Viewed (0) -
tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
} // Check if alpha is broadcastable for (int i = 0; i < alpha_type.getRank(); i++) { if (alpha_type.getDimSize(i) != input_type.getDimSize(i + 1) && alpha_type.getDimSize(i) != 1) { return op.emitOpError( llvm::formatv("'alpha' is not broadcastable at dimension {0}.", i)); } } }
Registered: Sun Jun 16 05:45:23 UTC 2024 - Last Modified: Thu May 02 09:41:17 UTC 2024 - 169.2K bytes - Viewed (0) -
tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc
if (shape_x.size() < 2 || shape_y.size() < 2) { return false; } // Checks outer dimensions (i.e., the dimensions higher than 2D) are // broadcastable. If true, then get the broadcasted shape for outer // dimension. if (!OpTrait::util::getBroadcastedShape( shape_x.drop_back(2), shape_y.drop_back(2), result_shape)) { return false;
Registered: Sun Jun 16 05:45:23 UTC 2024 - Last Modified: Thu Apr 25 16:01:03 UTC 2024 - 7.9K bytes - Viewed (0) -
tensorflow/compiler/mlir/lite/utils/validators.cc
}); } bool IsBroadcastableElementsAttrs(mlir::TypedAttr a, mlir::TypedAttr b) { // This would return false if we had unranked tensors (where they should // probably be considered as broadcastable), but given we are working with // attributes here that shouldn't be an issue, return OpTrait::util::getBroadcastedType(a.getType(), b.getType()) != Type(); }
Registered: Sun Jun 16 05:45:23 UTC 2024 - Last Modified: Thu Apr 25 16:01:03 UTC 2024 - 5.2K bytes - Viewed (0) -
tensorflow/compiler/mlir/lite/stablehlo/transforms/fuse_convolution_pass.cc
if (!result_type) { return rewriter.notifyMatchFailure(mul_op, [&](::mlir::Diagnostic &diag) { diag << "entities 'filter, multiplier' failed to satisfy constraint: " "non-broadcastable operands"; }); } filter_value = filter.getValue(); mul_value = multiplier.getValue(); // In MHLO, Conv filter is in HWIO format, Depthwise conv filter is in HW1O
Registered: Sun Jun 16 05:45:23 UTC 2024 - Last Modified: Thu Feb 22 22:21:19 UTC 2024 - 8.3K bytes - Viewed (0) -
tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.cc
auto result_type = OpTrait::util::getBroadcastedType(x.getType(), y.getType()); if (!result_type) { if (incompatible_shape_error.getValue()) { mlir::emitError(loc, "non-broadcastable operands"); } else { return UnrankedTensorType::get(builder->getI1Type()); } } auto ranked_type = mlir::dyn_cast<RankedTensorType>(result_type);
Registered: Sun Jun 16 05:45:23 UTC 2024 - Last Modified: Thu Apr 25 16:01:03 UTC 2024 - 6.7K bytes - Viewed (0) -
tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td
"$1.getType().cast<ShapedType>().hasRank() && " "$0.getType().cast<ShapedType>().getShape() == $1.getType().cast<ShapedType>().getShape()">, "Checks if the shapes of tensors are same.">; // Make the 1D value $0 broadcastable with the shape of $1. def MakeOneDimValueBroadcastable : NativeCodeCall< "MakeOneDimValueBroadcastable($_builder, $_loc, $0, $1.getType().cast<ShapedType>())">;
Registered: Sun Jun 16 05:45:23 UTC 2024 - Last Modified: Wed Feb 14 03:24:59 UTC 2024 - 8.4K bytes - Viewed (0) -
tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-binary-elementwise.mlir
func.return %1: tensor<2xi32> } // CHECK-LABEL: func @broadcast_add // TODO(laurenzo): Change this to a (5 + 2x1) shaped add to make the check // patterns unambiguous and more interesting (once broadcastable trait is // fixed upstream). func.func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
Registered: Sun Jun 16 05:45:23 UTC 2024 - Last Modified: Sat Apr 06 15:32:52 UTC 2024 - 18.4K bytes - Viewed (0)